├── dataset └── README.md ├── layers ├── __init__.py ├── RevIN.py └── Embed.py ├── logs └── .placeholder ├── model ├── __init__.py ├── TimesMamba.py └── mambacore.py ├── utils ├── __init__.py ├── masking.py ├── metrics.py ├── tools.py └── timefeatures.py ├── results └── .placeholder ├── checkpoints └── .placeholder ├── data_provider ├── __init__.py ├── data_factory.py └── data_loader.py ├── test_results └── .placeholder ├── figures ├── main_result.png └── architecture.png ├── .vscode ├── settings.json └── launch.json ├── run.sh ├── inspect_result.py ├── run_tuning.sh ├── scripts ├── multivariate_forecasting │ ├── ECL │ │ └── Mamba.sh │ ├── Traffic │ │ └── Mamba.sh │ └── ETT │ │ └── Mamba_ETTh1.sh └── tuning_all │ ├── SolarEnergy │ └── Mamba.sh │ ├── Exchange │ └── Mamba.sh │ ├── PEMS │ ├── Mamba_03.sh │ ├── Mamba_04.sh │ ├── Mamba_07.sh │ └── Mamba_08.sh │ ├── Weather │ └── Mamba.sh │ ├── ECL │ └── Mamba.sh │ ├── Traffic │ └── Mamba.sh │ └── ETT │ ├── Mamba_ETTh2.sh │ ├── Mamba_ETTm1.sh │ ├── Mamba_ETTm2.sh │ └── Mamba_ETTh1.sh ├── create_env.sh ├── loggingutil.py ├── experiments ├── exp_basic.py └── exp_long_term_forecasting.py ├── requirements.txt ├── LICENSE ├── .gitignore ├── README.md ├── run.py └── ThirdPartyNotices.txt /dataset/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_results/.placeholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/main_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaowen310/TimesMamba/HEAD/figures/main_result.png -------------------------------------------------------------------------------- /figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaowen310/TimesMamba/HEAD/figures/architecture.png -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "search.useIgnoreFiles": false, 3 | "files.exclude": { 4 | "**/.csv": true, 5 | "**/.log": true 6 | } 7 | } -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | export model_name=TimesMamba 3 | 4 | # multivariate_forecasting 
# %%
import os
import numpy as np

# %%
# Location of the saved metrics array: <model_dir>/<model_name>/metrics.npy.
model_dir = "./results"
model_name = ""
metrics_file = os.path.join(model_dir, model_name, "metrics.npy")

metrics = np.load(metrics_file)

# Positions 0..4 of the array are reported under these labels, in this order.
labels = ("MAE", "MSE", "RMSE", "MAPE", "MSPE")
for position, label in enumerate(labels):
    print(f"{label}: {metrics[position]:.6f}")

# One tab-separated summary row; positions 1 and 0 are swapped relative to
# the labelled listing above.
indexes = [1, 0, 2, 3, 4]
summary_cells = []
for ind in indexes:
    summary_cells.append(f"{metrics[ind]:.4f}")
print("\t".join(summary_cells) + "\n")

# %%
-------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=6 2 | # model_name=Mamba 3 | 4 | pred_lens=(96 192 336 720) 5 | 6 | for pred_len in "${pred_lens[@]}" 7 | do 8 | python -u run.py \ 9 | --is_training 1 \ 10 | --root_path ./dataset/electricity/ \ 11 | --data_path electricity.csv \ 12 | --model_id ECL_96_$pred_len \ 13 | --model $model_name \ 14 | --data custom \ 15 | --features M \ 16 | --seq_len 96 \ 17 | --pred_len $pred_len \ 18 | --e_layers 1 \ 19 | --enc_in 321 \ 20 | --dec_in 321 \ 21 | --c_out 321 \ 22 | --des 'Exp' \ 23 | --d_model 256 \ 24 | --r_ff 4 \ 25 | --revin_affine \ 26 | --dropout 0.1 \ 27 | --batch_size 32 \ 28 | --learning_rate 1e-3 \ 29 | --train_epochs 10 \ 30 | --itr 1 >&1 | tee logs/ECL_${pred_len}_${model_name}.log 31 | done 32 | -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/Traffic/Mamba.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=6 2 | # model_name=Mamba 3 | 4 | pred_lens=(96 192 336 720) 5 | 6 | for pred_len in "${pred_lens[@]}" 7 | do 8 | python -u run.py \ 9 | --is_training 1 \ 10 | --root_path ./dataset/traffic/ \ 11 | --data_path traffic.csv \ 12 | --model_id traffic_96_$pred_len \ 13 | --model $model_name \ 14 | --data custom \ 15 | --features M \ 16 | --use_mark \ 17 | --seq_len 96 \ 18 | --pred_len $pred_len \ 19 | --e_layers 3 \ 20 | --enc_in 862 \ 21 | --dec_in 862 \ 22 | --c_out 862 \ 23 | --des 'Exp' \ 24 | --d_model 512 \ 25 | --r_ff 2 \ 26 | --dropout 0.1 \ 27 | --batch_size 32 \ 28 | --learning_rate 1e-3 \ 29 | --train_epochs 10 \ 30 | --itr 1 >&1 | tee logs/Traffic_${pred_len}_${model_name}.log 31 | done 32 | -------------------------------------------------------------------------------- /create_env.sh: -------------------------------------------------------------------------------- 1 | # Create the conda environment 
import logging
import sys
from logging.handlers import TimedRotatingFileHandler

# Shared record layout used by every handler created in this module.
FORMATTER = logging.Formatter("%(asctime)s — %(name)s — %(levelname)s — %(message)s")
LOG_FILE = "my_app.log"


def get_console_handler():
    """Return a stream handler that writes formatted records to stdout."""
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(FORMATTER)
    return console_handler


def get_file_handler():
    """Return a handler that writes to ``LOG_FILE`` and rotates at midnight.

    The log file is opened when this function is called, not at import time.
    """
    file_handler = TimedRotatingFileHandler(LOG_FILE, when="midnight")
    file_handler.setFormatter(FORMATTER)
    return file_handler


def get_logger(logger_name):
    """Return the DEBUG-level logger named ``logger_name`` with console output.

    ``logging.getLogger`` returns the same object for the same name, so this
    function may be called repeatedly. A console handler is attached only
    when the logger has no handler yet; otherwise each extra call would add
    another handler and every record would be printed multiple times.

    :param logger_name: name passed to ``logging.getLogger``
    :return: the configured ``logging.Logger``
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        logger.addHandler(get_console_handler())
    # Records are handled here only; do not forward them to ancestor loggers.
    logger.propagate = False
    return logger
import numpy as np


def RSE(pred, true):
    """Return sqrt(sum((true - pred)^2)) / sqrt(sum((true - mean(true))^2))."""
    return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))


def CORR(pred, true):
    """Return the Pearson correlation along axis 0, averaged over the last axis.

    For each series (everything but axis 0) the correlation is
    ``sum(a * b) / sqrt(sum(a^2) * sum(b^2))`` where ``a`` and ``b`` are the
    mean-centred ``true`` and ``pred`` values.

    :param pred: predicted values (array; axis 0 is the sample axis)
    :param true: reference values, same shape as ``pred``
    :return: the per-series correlations averaged over the last axis
    """
    true_centered = true - true.mean(0)
    pred_centered = pred - pred.mean(0)
    u = (true_centered * pred_centered).sum(0)
    # The two sums of squares are multiplied AFTER summing; summing the
    # product of the squares instead is not a correlation and can exceed 1.
    d = np.sqrt((true_centered ** 2).sum(0) * (pred_centered ** 2).sum(0))
    return (u / d).mean(-1)


def MAE(pred, true):
    """Return the mean absolute error."""
    return np.mean(np.abs(pred - true))


def MSE(pred, true):
    """Return the mean squared error."""
    return np.mean((pred - true) ** 2)


def RMSE(pred, true):
    """Return the square root of :func:`MSE`."""
    return np.sqrt(MSE(pred, true))


def MAPE(pred, true):
    """Return the mean of ``|pred - true| / |true|`` (no guard for zeros in ``true``)."""
    return np.mean(np.abs((pred - true) / true))


def MSPE(pred, true):
    """Return the mean of ``((pred - true) / true)^2`` (no guard for zeros in ``true``)."""
    return np.mean(np.square((pred - true) / true))


def metric(pred, true):
    """Return the tuple ``(mae, mse, rmse, mape, mspe)`` for ``pred`` vs ``true``."""
    mae = MAE(pred, true)
    mse = MSE(pred, true)
    rmse = RMSE(pred, true)
    mape = MAPE(pred, true)
    mspe = MSPE(pred, true)

    return mae, mse, rmse, mape, mspe
class Exp_Basic:
    """Base class for experiments: stores args, picks a device, builds the model.

    Subclasses must override ``_build_model``; the remaining hooks
    (``_get_data``, ``vali``, ``train``, ``test``) are no-ops here.
    """

    def __init__(self, args):
        self.args = args
        # Registry from model name to the model entry imported from ``model``.
        self.model_dict = {"TimesMamba": TimesMamba}
        self.device = self._acquire_device()
        built_model = self._build_model()
        self.model = built_model.to(self.device)

    def _build_model(self):
        """Create and return the model; must be provided by subclasses."""
        raise NotImplementedError

    def _acquire_device(self):
        """Return the torch device selected by ``args.use_gpu`` / ``args.device``."""
        if not self.args.use_gpu:
            chosen = torch.device("cpu")
            print("Use CPU")
            return chosen

        gpu_index = self.args.device
        chosen = torch.device("cuda:{}".format(gpu_index))
        print("Use GPU: cuda:{}".format(gpu_index))
        return chosen

    def _get_data(self):
        """Hook for subclasses; does nothing here."""
        return None

    def vali(self):
        """Hook for subclasses; does nothing here."""
        return None

    def train(self):
        """Hook for subclasses; does nothing here."""
        return None

    def test(self):
        """Hook for subclasses; does nothing here."""
        return None
32 \ 31 | --learning_rate $lr \ 32 | --train_epochs 10 \ 33 | --itr 1 >&1 | tee SolarE_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_tol1e-3_lr$lr.log 34 | done 35 | done 36 | done 37 | done 38 | -------------------------------------------------------------------------------- /scripts/tuning_all/Exchange/Mamba.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 2 | # model_name=STMamba 3 | 4 | pred_lens=(96 192 336 720) 5 | lrs=(1e-3) 6 | d_models=(32) 7 | e_layerss=(1) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in "${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/exchange_rate/ \ 16 | --data_path exchange_rate.csv \ 17 | --model_id Exchange_96_$pred_len \ 18 | --model $model_name \ 19 | --data custom \ 20 | --features M \ 21 | --seq_len 96 \ 22 | --pred_len $pred_len \ 23 | --e_layers $e_layers \ 24 | --enc_in 8 \ 25 | --dec_in 8 \ 26 | --c_out 8 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --batch_size 32 \ 31 | --learning_rate $lr \ 32 | --train_epochs 10 \ 33 | --itr 1 >&1 | tee Exchange_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log 34 | done 35 | done 36 | done 37 | done 38 | -------------------------------------------------------------------------------- /scripts/tuning_all/PEMS/Mamba_03.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 2 | # model_name=STMamba 3 | 4 | pred_lens=(12 24 48 96) 5 | lrs=(1e-3 5e-4 1e-4) 6 | d_models=(512) 7 | e_layerss=(3) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in "${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/PEMS/ \ 16 | --data_path PEMS03.npz \ 17 | --model_id 
PEMS03_96_$pred_len \ 18 | --model $model_name \ 19 | --data PEMS \ 20 | --features M \ 21 | --seq_len 96 \ 22 | --pred_len $pred_len \ 23 | --e_layers $e_layers \ 24 | --enc_in 358 \ 25 | --dec_in 358 \ 26 | --c_out 358 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --learning_rate $lr \ 31 | --train_epochs 10 \ 32 | --no_norm \ 33 | --batch_size 32 \ 34 | --itr 1 >&1 | tee PEMS03_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log 35 | done 36 | done 37 | done 38 | done 39 | -------------------------------------------------------------------------------- /scripts/tuning_all/PEMS/Mamba_04.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 2 | # model_name=STMamba 3 | 4 | pred_lens=(12 24 48 96) 5 | lrs=(1e-3 5e-4 1e-4) 6 | d_models=(512) 7 | e_layerss=(3) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in "${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/PEMS/ \ 16 | --data_path PEMS04.npz \ 17 | --model_id PEMS04_96_$pred_len \ 18 | --model $model_name \ 19 | --data PEMS \ 20 | --features M \ 21 | --seq_len 96 \ 22 | --pred_len $pred_len \ 23 | --e_layers $e_layers \ 24 | --enc_in 307 \ 25 | --dec_in 307 \ 26 | --c_out 307 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --learning_rate $lr \ 31 | --train_epochs 10 \ 32 | --no_norm \ 33 | --batch_size 32 \ 34 | --itr 1 >&1 | tee PEMS04_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log 35 | done 36 | done 37 | done 38 | done 39 | -------------------------------------------------------------------------------- /scripts/tuning_all/PEMS/Mamba_07.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 2 | # model_name=STMamba 3 | 4 | pred_lens=(12 24 48 96) 5 | lrs=(1e-3 5e-4 1e-4) 6 | 
d_models=(512) 7 | e_layerss=(3) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in "${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/PEMS/ \ 16 | --data_path PEMS07.npz \ 17 | --model_id PEMS07_96_$pred_len \ 18 | --model $model_name \ 19 | --data PEMS \ 20 | --features M \ 21 | --seq_len 96 \ 22 | --pred_len $pred_len \ 23 | --e_layers $e_layers \ 24 | --enc_in 883 \ 25 | --dec_in 883 \ 26 | --c_out 883 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --learning_rate $lr \ 31 | --train_epochs 10 \ 32 | --no_norm \ 33 | --batch_size 32 \ 34 | --itr 1 >&1 | tee PEMS07_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log 35 | done 36 | done 37 | done 38 | done 39 | -------------------------------------------------------------------------------- /scripts/tuning_all/Weather/Mamba.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=6 2 | # model_name=iBMamba 3 | 4 | pred_lens=(96 192 336 720) 5 | lrs=(1e-4 5e-5) 6 | d_models=(256) 7 | e_layerss=(3) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in "${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/weather/ \ 16 | --data_path weather.csv \ 17 | --model_id weather_96_$pred_len \ 18 | --model $model_name \ 19 | --data custom \ 20 | --features M \ 21 | --seq_len 96 \ 22 | --pred_len $pred_len \ 23 | --e_layers $e_layers \ 24 | --enc_in 21 \ 25 | --dec_in 21 \ 26 | --c_out 21 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --dropout 0.1 \ 31 | --batch_size 32 \ 32 | --learning_rate $lr \ 33 | --train_epochs 10 \ 34 | --itr 1 >&1 | tee Weather_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log 35 | done 36 | done 37 | done 38 | done 39 
# export CUDA_VISIBLE_DEVICES=0
# model_name=STMamba

# Hyper-parameter sweep on PEMS08: one run.py launch per
# (pred_len, e_layers, d_model, lr) combination, each logged through tee.
# $model_name is not set here; it must be provided by the caller's
# environment (run_tuning.sh exports it).

pred_lens=(12 24 48 96)
# NOTE(review): use_norms is declared in parallel with pred_lens but is never
# read below. The sibling PEMS scripts pass --no_norm; confirm whether it
# should be applied here for the entries set to 0.
use_norms=(1 1 0 0)
lrs=(1e-3 5e-4 1e-4)
d_models=(512)
e_layerss=(3)

# Valid indices are 0 .. length-1, so the bound must be `<`. With `<=` the
# final pass expands ${pred_lens[i]} to an empty string and run.py is
# launched with a missing --pred_len value.
for ((i = 0; i < ${#pred_lens[@]}; i++)); do
  for e_layers in "${e_layerss[@]}"; do
    for d_model in "${d_models[@]}"; do
      for lr in "${lrs[@]}"; do
        python -u run.py \
          --is_training 1 \
          --root_path ./dataset/PEMS/ \
          --data_path PEMS08.npz \
          --model_id PEMS08_96_${pred_lens[i]} \
          --model $model_name \
          --data PEMS \
          --features M \
          --seq_len 96 \
          --pred_len ${pred_lens[i]} \
          --e_layers $e_layers \
          --enc_in 170 \
          --dec_in 170 \
          --c_out 170 \
          --des 'Exp' \
          --d_model $d_model \
          --r_ff 4 \
          --learning_rate $lr \
          --train_epochs 10 \
          --batch_size 32 \
          --itr 1 >&1 | tee PEMS08_${pred_lens[i]}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log
      done
    done
  done
done
24 | --enc_in 321 \ 25 | --dec_in 321 \ 26 | --c_out 321 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --revin_affine \ 31 | --dropout 0.1 \ 32 | --batch_size 32 \ 33 | --learning_rate $lr \ 34 | --train_epochs 10 \ 35 | --itr 1 >&1 | tee ECL_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_rff4_lr$lr.log 36 | done 37 | done 38 | done 39 | done 40 | -------------------------------------------------------------------------------- /scripts/tuning_all/Traffic/Mamba.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=6 2 | # model_name=iBMamba 3 | 4 | pred_lens=(96 192 336 720) 5 | lrs=(1e-3) 6 | d_models=(512) 7 | e_layerss=(3) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in "${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/traffic/ \ 16 | --data_path traffic.csv \ 17 | --model_id traffic_96_$pred_len \ 18 | --model $model_name \ 19 | --data custom \ 20 | --features M \ 21 | --use_mark \ 22 | --seq_len 96 \ 23 | --pred_len $pred_len \ 24 | --e_layers $e_layers \ 25 | --enc_in 862 \ 26 | --dec_in 862 \ 27 | --c_out 862 \ 28 | --des 'Exp' \ 29 | --d_model $d_model \ 30 | --r_ff 2 \ 31 | --dropout 0.1 \ 32 | --batch_size 32 \ 33 | --learning_rate $lr \ 34 | --train_epochs 10 \ 35 | --itr 1 >&1 | tee Traffic_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_rff2_afff_lr$lr.log 36 | done 37 | done 38 | done 39 | done 40 | -------------------------------------------------------------------------------- /scripts/tuning_all/ETT/Mamba_ETTh2.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 2 | # model_name=STMamba 3 | 4 | pred_lens=(96 192 336 720) 5 | lrs=(1e-3) 6 | d_models=(32) 7 | e_layerss=(1) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in 
"${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/ETT-small/ \ 16 | --data_path ETTh2.csv \ 17 | --model_id ETTh2_96_$pred_len \ 18 | --model $model_name \ 19 | --data ETTh2 \ 20 | --features M \ 21 | --seq_len 96 \ 22 | --pred_len $pred_len \ 23 | --e_layers $e_layers \ 24 | --enc_in 7 \ 25 | --dec_in 7 \ 26 | --c_out 7 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --revin_affine \ 31 | --channel_independence \ 32 | --ssm_expand 0 \ 33 | --batch_size 32 \ 34 | --learning_rate $lr \ 35 | --train_epochs 10 \ 36 | --itr 1 >&1 | tee ETTh2_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log 37 | done 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /scripts/tuning_all/ETT/Mamba_ETTm1.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 2 | # model_name=STMamba 3 | 4 | pred_lens=(96 192 336 720) 5 | lrs=(1e-3) 6 | d_models=(32) 7 | e_layerss=(1) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in "${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/ETT-small/ \ 16 | --data_path ETTm1.csv \ 17 | --model_id ETTm1_96_$pred_len \ 18 | --model $model_name \ 19 | --data ETTm1 \ 20 | --features M \ 21 | --seq_len 96 \ 22 | --pred_len $pred_len \ 23 | --e_layers $e_layers \ 24 | --enc_in 7 \ 25 | --dec_in 7 \ 26 | --c_out 7 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --revin_affine \ 31 | --channel_independence \ 32 | --ssm_expand 0 \ 33 | --batch_size 32 \ 34 | --learning_rate $lr \ 35 | --train_epochs 10 \ 36 | --itr 1 >&1 | tee ETTm1_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log 37 | done 38 | done 39 | done 40 | done 41 | 
-------------------------------------------------------------------------------- /scripts/tuning_all/ETT/Mamba_ETTm2.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 2 | # model_name=STMamba 3 | 4 | pred_lens=(96 192 336 720) 5 | lrs=(1e-3) 6 | d_models=(32) 7 | e_layerss=(1) 8 | 9 | for pred_len in "${pred_lens[@]}"; do 10 | for e_layers in "${e_layerss[@]}"; do 11 | for d_model in "${d_models[@]}"; do 12 | for lr in "${lrs[@]}"; do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path ./dataset/ETT-small/ \ 16 | --data_path ETTm2.csv \ 17 | --model_id ETTm2_96_$pred_len \ 18 | --model $model_name \ 19 | --data ETTm2 \ 20 | --features M \ 21 | --seq_len 96 \ 22 | --pred_len $pred_len \ 23 | --e_layers $e_layers \ 24 | --enc_in 7 \ 25 | --dec_in 7 \ 26 | --c_out 7 \ 27 | --des 'Exp' \ 28 | --d_model $d_model \ 29 | --r_ff 4 \ 30 | --revin_affine \ 31 | --channel_independence \ 32 | --ssm_expand 0 \ 33 | --batch_size 32 \ 34 | --learning_rate $lr \ 35 | --train_epochs 10 \ 36 | --itr 1 >&1 | tee ETTm2_${pred_len}_${model_name}_el${e_layers}_dm${d_model}_lr$lr.log 37 | done 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==25.3.0 2 | Automat==25.4.16 3 | buildtools==1.0.6 4 | causal-conv1d==1.1.0 5 | certifi==2025.4.26 6 | charset-normalizer==3.4.1 7 | constantly==23.10.4 8 | contourpy==1.3.2 9 | cycler==0.12.1 10 | docopt==0.6.2 11 | einops==0.8.1 12 | filelock==3.13.1 13 | fonttools==4.57.0 14 | fsspec==2024.6.1 15 | furl==2.1.4 16 | greenlet==3.2.1 17 | huggingface-hub==0.30.2 18 | hyperlink==21.0.0 19 | idna==3.10 20 | incremental==24.7.2 21 | Jinja2==3.1.4 22 | joblib==1.4.2 23 | kiwisolver==1.4.8 24 | mamba-ssm==1.1.0 25 | MarkupSafe==2.1.5 26 | matplotlib==3.10.1 27 | mpmath==1.3.0 28 | networkx==3.3 29 | 
def data_provider(args, flag):
    """Build the dataset and its ``DataLoader`` for the split named by ``flag``.

    :param args: namespace read for ``data``, ``root_path``, ``data_path``,
        ``seq_len``, ``label_len``, ``pred_len``, ``features``, ``target``,
        ``embed``, ``freq``, ``batch_size`` and ``num_workers``
    :param flag: split name; ``"pred"`` selects ``Dataset_Pred``, any other
        value selects the class registered in ``data_dict`` under
        ``args.data``; only ``"train"`` is shuffled
    :return: ``(data_set, data_loader)``
    """
    if flag == "pred":
        dataset_cls = Dataset_Pred
    else:
        dataset_cls = data_dict[args.data]

    # Keyword arguments are collected in the same order the dataset receives them.
    dataset_kwargs = dict(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=0 if args.embed != "timeF" else 1,
        freq=args.freq,
    )
    data_set = dataset_cls(**dataset_kwargs)
    print(flag, len(data_set))

    shuffle_batches = flag == "train"
    data_loader = DataLoader(
        data_set,
        batch_size=args.batch_size,
        shuffle=shuffle_batches,
        num_workers=args.num_workers,
    )

    return data_set, data_loader
10 | 11 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 14 | 15 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. 
16 | -------------------------------------------------------------------------------- /layers/RevIN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RevIN(nn.Module): 6 | def __init__(self, num_features: int, eps=1e-5, affine=True): 7 | """ 8 | :param num_features: the number of features or channels 9 | :param eps: a value added for numerical stability 10 | :param affine: if True, RevIN has learnable affine parameters 11 | """ 12 | super(RevIN, self).__init__() 13 | self.num_features = num_features 14 | self.eps = eps 15 | self.affine = affine 16 | if self.affine: 17 | self._init_params() 18 | 19 | def forward(self, x, mode: str): 20 | if mode == "norm": 21 | self._get_statistics(x) # caches self.mean / self.stdev, which a later "denorm" call reads 22 | x = self._normalize(x) 23 | elif mode == "denorm": 24 | x = self._denormalize(x) # relies on statistics stored by a previous "norm" call 25 | else: 26 | raise NotImplementedError # only "norm" and "denorm" are handled 27 | return x 28 | 29 | def _init_params(self): 30 | # initialize RevIN params: (C,) 31 | self.affine_weight = nn.Parameter(torch.ones(self.num_features)) 32 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) 33 | 34 | def _get_statistics(self, x): 35 | dim2reduce = tuple(range(1, x.ndim - 1)) # every axis except the first and the last 36 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() 37 | self.stdev = torch.sqrt( 38 | torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps 39 | ).detach() # both statistics are detached, so no gradient flows through them 40 | 41 | def _normalize(self, x): 42 | x = x - self.mean 43 | x = x / self.stdev 44 | if self.affine: 45 | x = x * self.affine_weight 46 | x = x + self.affine_bias 47 | return x 48 | 49 | def _denormalize(self, x): 50 | if self.affine: 51 | x = x - self.affine_bias 52 | x = x / (self.affine_weight + self.eps * self.eps) # eps**2 keeps the divisor away from an exactly-zero weight 53 | x = x * self.stdev 54 | x = x + self.mean 55 | return x 56 | -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ETT/Mamba_ETTh1.sh:
-------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=0 2 | # model_name=Mamba 3 | 4 | python -u run.py \ 5 | --is_training 1 \ 6 | --root_path ./dataset/ETT-small/ \ 7 | --data_path ETTh1.csv \ 8 | --model_id ETTh1_96_96 \ 9 | --model $model_name \ 10 | --data ETTh1 \ 11 | --features M \ 12 | --seq_len 96 \ 13 | --pred_len 96 \ 14 | --e_layers 1 \ 15 | --enc_in 7 \ 16 | --dec_in 7 \ 17 | --c_out 7 \ 18 | --des 'Exp' \ 19 | --d_model 32 \ 20 | --r_ff 4 \ 21 | --revin_affine \ 22 | --channel_independence \ 23 | --dropout 0.1 \ 24 | --batch_size 32 \ 25 | --learning_rate 5e-3 \ 26 | --train_epochs 10 \ 27 | --itr 1 >&1 | tee logs/ETTh1_96_${model_name}.log 28 | 29 | python -u run.py \ 30 | --is_training 1 \ 31 | --root_path ./dataset/ETT-small/ \ 32 | --data_path ETTh1.csv \ 33 | --model_id ETTh1_96_192 \ 34 | --model $model_name \ 35 | --data ETTh1 \ 36 | --features M \ 37 | --seq_len 96 \ 38 | --pred_len 192 \ 39 | --e_layers 1 \ 40 | --enc_in 7 \ 41 | --dec_in 7 \ 42 | --c_out 7 \ 43 | --des 'Exp' \ 44 | --d_model 32 \ 45 | --r_ff 4 \ 46 | --revin_affine \ 47 | --channel_independence \ 48 | --dropout 0.1 \ 49 | --batch_size 32 \ 50 | --learning_rate 1e-3 \ 51 | --train_epochs 10 \ 52 | --itr 1 >&1 | tee logs/ETTh1_192_${model_name}.log 53 | 54 | python -u run.py \ 55 | --is_training 1 \ 56 | --root_path ./dataset/ETT-small/ \ 57 | --data_path ETTh1.csv \ 58 | --model_id ETTh1_96_336 \ 59 | --model $model_name \ 60 | --data ETTh1 \ 61 | --features M \ 62 | --seq_len 96 \ 63 | --pred_len 336 \ 64 | --e_layers 1 \ 65 | --enc_in 7 \ 66 | --dec_in 7 \ 67 | --c_out 7 \ 68 | --des 'Exp' \ 69 | --d_model 32 \ 70 | --r_ff 4 \ 71 | --revin_affine \ 72 | --channel_independence \ 73 | --dropout 0.1 \ 74 | --batch_size 32 \ 75 | --learning_rate 5e-4 \ 76 | --train_epochs 10 \ 77 | --itr 1 >&1 | tee logs/ETTh1_336_${model_name}.log 78 | 79 | python -u run.py \ 80 | --is_training 1 \ 81 | --root_path 
./dataset/ETT-small/ \ 82 | --data_path ETTh1.csv \ 83 | --model_id ETTh1_96_720 \ 84 | --model $model_name \ 85 | --data ETTh1 \ 86 | --features M \ 87 | --seq_len 96 \ 88 | --pred_len 720 \ 89 | --e_layers 1 \ 90 | --enc_in 7 \ 91 | --dec_in 7 \ 92 | --c_out 7 \ 93 | --des 'Exp' \ 94 | --d_model 64 \ 95 | --r_ff 4 \ 96 | --revin_affine \ 97 | --channel_independence \ 98 | --dropout 0.1 \ 99 | --batch_size 32 \ 100 | --learning_rate 5e-4 \ 101 | --train_epochs 10 \ 102 | --itr 1 >&1 | tee logs/ETTh1_720_${model_name}.log 103 | -------------------------------------------------------------------------------- /model/TimesMamba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from layers.Embed import SeriesEmbedding 5 | from layers.RevIN import RevIN 6 | from model.mambacore import MambaForSeriesForecasting 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self, config): 11 | super(Model, self).__init__() 12 | self.seq_len = config.seq_len 13 | self.pred_len = config.pred_len 14 | self.use_norm = config.use_norm 15 | self.use_mark = config.use_mark 16 | self.channel_independence = config.channel_independence 17 | self.d_model = config.d_model 18 | 19 | if self.use_norm: 20 | self.revin_layer = RevIN(config.enc_in, affine=config.revin_affine) 21 | 22 | # Embedding 23 | self.enc_embedding = SeriesEmbedding( 24 | config.seq_len, 25 | config.d_model, 26 | config.dropout, 27 | ) 28 | print(self.enc_embedding) 29 | 30 | # Encoder-only architecture 31 | self.mamba = MambaForSeriesForecasting( 32 | dims=[config.d_model], 33 | depths=[config.e_layers], 34 | ssm_expand=config.ssm_expand, 35 | ssm_drop_rate=config.dropout, 36 | mlp_ratio=config.r_ff, 37 | mlp_drop_rate=config.dropout, 38 | drop_path_rate=config.dropout, 39 | ) 40 | print(self.mamba) 41 | 42 | self.projector = nn.Linear( 43 | config.d_model, 44 | config.pred_len, 45 | bias=True, 46 | ) 47 | 48 | def forecast(self, x_enc, 
x_mark_enc, x_dec, x_mark_dec): 49 | B, _, N = x_enc.shape # b ts nv 50 | 51 | if self.use_norm: 52 | x_enc = self.revin_layer(x_enc, "norm") 53 | 54 | if not self.use_mark: 55 | x_mark_enc = None 56 | 57 | enc_out = self.enc_embedding(x_enc, x_mark_enc) # b nv c 58 | 59 | if self.channel_independence: 60 | enc_out = enc_out.reshape((-1, 1, self.d_model)) # b*nv 1 c 61 | 62 | enc_out = torch.unsqueeze(enc_out, 1) # b 1 nv c 63 | enc_out = self.mamba(enc_out) # b 1 nv c 64 | enc_out = torch.squeeze(enc_out, 1) # b nv c 65 | 66 | if self.channel_independence: 67 | enc_out = enc_out.reshape((B, -1, self.d_model)) # b nv c 68 | 69 | enc_out = self.projector(enc_out).transpose(1, 2)[:, :, :N] # b ts nv 70 | 71 | if self.use_norm: 72 | enc_out = self.revin_layer(enc_out, "denorm") 73 | 74 | return enc_out 75 | 76 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 77 | enc_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 78 | return enc_out[:, -self.pred_len :, :] 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .DS_Store 163 | 164 | dataset/ 165 | !dataset/README.md 166 | checkpoints 167 | !checkpoints/.placeholder 168 | results 169 | !results/.placeholder 170 | test_results 171 | !test_results/.placeholder 172 | result_*.txt 173 | my_app.log* 174 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import optim 4 | import matplotlib.pyplot as plt 5 | 6 | plt.switch_backend("agg") 7 | 8 | 9 | def get_lr_scheduler(optimizer, train_epochs, warmup_epochs=0, lradj="type1"): 10 | schedulers = [] 11 | lr_sched_milestones = [] 12 | 13 | if warmup_epochs > 0: 14 | warmup_fn = lambda c: 1 / (10 ** (float(warmup_epochs - c))) 15 | warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn) 16 | schedulers.append(warmup_scheduler) 17 | lr_sched_milestones.append(warmup_epochs) 18 | 19 | if lradj == "type1": 20 | cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR( 21 | optimizer, train_epochs - warmup_epochs, eta_min=1e-5 22 | ) 23 | schedulers.append(cosine_scheduler) 24 | elif lradj == "type2": 25 | exp_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.5) 26 | schedulers.append(exp_scheduler) 27 | else: 28 | ms_scheduler = optim.lr_scheduler.MultiStepLR( 29 | optimizer, milestones=[2, 4, 6, 8, 10, 15, 20], gamma=0.5 30 | ) 31 | schedulers.append(ms_scheduler) 32 | 33 | return optim.lr_scheduler.SequentialLR(optimizer, schedulers, lr_sched_milestones) 34 | 35 | 36 | class EarlyStopping: 37 | def __init__(self, patience=3, verbose=False, delta=0): 38 | self.patience = patience 39 | self.verbose = verbose 40 | self.counter = 0 41 | self.best_score = None 42 | self.early_stop = False 43 | self.val_loss_min = np.Inf 44 | self.delta = delta 45 | 46 | def __call__(self, val_loss, model, path): 47 | score = val_loss 48 | if self.best_score is None: 49 | 
self.best_score = score 50 | self.save_checkpoint(val_loss, model, path) 51 | elif score > self.best_score - self.delta: 52 | self.counter += 1 53 | print(f"EarlyStopping counter: {self.counter} out of {self.patience}") 54 | if self.patience >= 0 and self.counter >= self.patience: 55 | self.early_stop = True 56 | else: 57 | self.best_score = min(score, self.best_score) 58 | self.save_checkpoint(val_loss, model, path) 59 | self.counter = 0 60 | 61 | def save_checkpoint(self, val_loss, model, path): 62 | if self.verbose: 63 | print( 64 | f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..." 65 | ) 66 | torch.save(model.state_dict(), path + "/" + "checkpoint.pth") 67 | self.val_loss_min = min(val_loss, self.val_loss_min) 68 | 69 | 70 | class dotdict(dict): 71 | """dot.notation access to dictionary attributes""" 72 | 73 | __getattr__ = dict.get 74 | __setattr__ = dict.__setitem__ 75 | __delattr__ = dict.__delitem__ 76 | 77 | 78 | class StandardScaler: 79 | def __init__(self, mean, std): 80 | self.mean = mean 81 | self.std = std 82 | 83 | def transform(self, data): 84 | return (data - self.mean) / self.std 85 | 86 | def inverse_transform(self, data): 87 | return (data * self.std) + self.mean 88 | 89 | 90 | def visual(true, preds=None, name="./pic/test.pdf"): 91 | """ 92 | Results visualization 93 | """ 94 | plt.figure() 95 | plt.plot(true, label="GroundTruth", linewidth=2) 96 | if preds is not None: 97 | plt.plot(preds, label="Prediction", linewidth=2) 98 | plt.legend() 99 | plt.savefig(name, bbox_inches="tight") 100 | 101 | 102 | def adjustment(gt, pred): 103 | anomaly_state = False 104 | for i in range(len(gt)): 105 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 106 | anomaly_state = True 107 | for j in range(i, 0, -1): 108 | if gt[j] == 0: 109 | break 110 | else: 111 | if pred[j] == 0: 112 | pred[j] = 1 113 | for j in range(i, len(gt)): 114 | if gt[j] == 0: 115 | break 116 | else: 117 | if pred[j] == 0: 118 | pred[j] = 1 
119 | elif gt[i] == 0: 120 | anomaly_state = False 121 | if anomaly_state: 122 | pred[i] = 1 123 | return gt, pred 124 | 125 | 126 | def cal_accuracy(y_pred, y_true): 127 | return np.mean(y_pred == y_true) 128 | -------------------------------------------------------------------------------- /utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 5 | # You may not use this file except in compliance with the License. 6 | # A copy of the License is located at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # or in the "license" file accompanying this file. This file is distributed 11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 | # express or implied. See the License for the specific language governing 13 | # permissions and limitations under the License. 
14 | 15 | from typing import List 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from pandas.tseries import offsets 20 | from pandas.tseries.frequencies import to_offset 21 | 22 | 23 | class TimeFeature: 24 | def __init__(self): 25 | pass 26 | 27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 28 | pass # base implementation returns None; each subclass below overrides __call__ 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + "()" 32 | 33 | 34 | class SecondOfMinute(TimeFeature): 35 | """Second of minute encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.second / 59.0 - 0.5 39 | 40 | 41 | class MinuteOfHour(TimeFeature): 42 | """Minute of hour encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.minute / 59.0 - 0.5 46 | 47 | 48 | class HourOfDay(TimeFeature): 49 | """Hour of day encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return index.hour / 23.0 - 0.5 53 | 54 | 55 | class DayOfWeek(TimeFeature): 56 | """Day of week encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return index.dayofweek / 6.0 - 0.5 60 | 61 | 62 | class DayOfMonth(TimeFeature): 63 | """Day of month encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.day - 1) / 30.0 - 0.5 67 | 68 | 69 | class DayOfYear(TimeFeature): 70 | """Day of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.dayofyear - 1) / 365.0 - 0.5 74 | 75 | 76 | class MonthOfYear(TimeFeature): 77 | """Month of year encoded as value between [-0.5, 0.5]""" 78 | 79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 80 | return (index.month - 1) / 11.0 - 0.5 81 | 82 | 83 | class WeekOfYear(TimeFeature): 84 | """Week of year encoded
as value between [-0.5, 0.5]""" 85 | 86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 87 | return (index.isocalendar().week - 1) / 52.0 - 0.5 88 | 89 | 90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 91 | """ 92 | Returns a list of time features that will be appropriate for the given frequency string. 93 | Parameters 94 | ---------- 95 | freq_str 96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 97 | """ 98 | 99 | features_by_offsets = { 100 | offsets.YearEnd: [], 101 | offsets.QuarterEnd: [MonthOfYear], 102 | offsets.MonthEnd: [MonthOfYear], 103 | offsets.Week: [DayOfMonth, WeekOfYear], 104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 107 | offsets.Minute: [ 108 | MinuteOfHour, 109 | HourOfDay, 110 | DayOfWeek, 111 | DayOfMonth, 112 | DayOfYear, 113 | ], 114 | offsets.Second: [ 115 | SecondOfMinute, 116 | MinuteOfHour, 117 | HourOfDay, 118 | DayOfWeek, 119 | DayOfMonth, 120 | DayOfYear, 121 | ], 122 | } 123 | 124 | offset = to_offset(freq_str) 125 | 126 | for offset_type, feature_classes in features_by_offsets.items(): 127 | if isinstance(offset, offset_type): 128 | return [cls() for cls in feature_classes] 129 | 130 | supported_freq_msg = f""" 131 | Unsupported frequency {freq_str} 132 | The following frequencies are supported: 133 | Y - yearly 134 | alias: A 135 | M - monthly 136 | W - weekly 137 | D - daily 138 | B - business days 139 | H - hourly 140 | T - minutely 141 | alias: min 142 | S - secondly 143 | """ 144 | raise RuntimeError(supported_freq_msg) 145 | 146 | 147 | def time_features(dates, freq='h'): 148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 149 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Mamba for Multivariate Time Series Forecasting 2 | 3 | The repo implements a Mamba-empowered model for multivariate time series forecasting - TimesMamba. 4 | 5 | 6 | 7 | ## Introduction 8 | 9 | 🌟 We implement a Mamba-based model for multi-variate time series forecasting. We use series embedding and let Mamba learn the correlations among the series. 10 | 11 | 🌟 Without attention, Mamba achieves the same modeling power as Transformer while consuming much less VRAM. 12 | 13 | 🌟 We conduct experiments demonstrating that Mamba performs better on multiple long-term time series forecasting benchmarks. 14 | 15 | ## Overall Architecture 16 | 17 | TimesMamba consists of the following components: 18 | 19 | - A reversible instance normalization layer to normalize individual series. 20 | - A linear embedding layer that projects individual series into the embedding space. 21 | - Mamba blocks that capture the correlations among the series. 22 | 23 | A Mamba block contains two major components: a bidirectional Mamba and a FFN. The Mamba module replaces the self-attention layer in a typical Transformer encoder. `Scan+` and `Scan-` in the diagram indicate the normal and reverse order scanning. 24 | 25 | ![image](./figures/architecture.png) 26 | 27 | We regard the series embedding as a special case of patching where the patch length equals the series length. We consider both channel-mixing and channel-independent modes. 28 | 29 | ## Datasets 30 | 31 | The datasets can be obtained by following the instructions in this [repo](https://github.com/thuml/iTransformer). 32 | 33 | ## Usage 34 | 35 | ### Environment setup 36 | 37 | 1. Install miniconda. 38 | 39 | 2. Create a conda environment and install Pytorch and necessary dependencies. The environment is named as `timesmamba`. You may run the following bash script. 40 | 41 | ```bash 42 | # Tested to work on Linux, but not on WSL. 
43 | bash create_env.sh 44 | ``` 45 | 46 | ### Model training and evaluation 47 | 48 | Train and evaluate the model. We provide the scripts under the folder `./scripts/`. You can reproduce the results as in the following examples: 49 | 50 | ```bash 51 | # Multivariate forecasting with TimesMamba 52 | bash run.sh 53 | 54 | # Tuning parameters 55 | bash run_tuning.sh 56 | ``` 57 | 58 | ## Main Result of Multivariate Forecasting 59 | 60 | We evaluate TimesMamba on ETTh1, Electricity (ECL), and Traffic. These three datasets have diverse characteristics. 61 | 62 | We set the lookback window to 96. We use channel independence for ETTh1 and channel mixing for the others. We train the model for ten epochs and select the model having the best validation loss within these ten epochs. 63 | 64 | The main results are shown in the following table. The results of iTransformer and PatchTST are extracted from the iTransformer paper. 65 | 66 | ![image](./figures/main_result.png) 67 | 68 | ## Findings 69 | 70 | ### What role does Mamba play? 71 | 72 | We find that Mamba is a good substitute for the attention module in Transformer. It is also more effective at capturing long-term dependencies while enjoying linear growth in terms of computation and VRAM cost. You may turn off the Mamba module by supplying `--ssm_expand 0`. 73 | 74 | ### What role does FFN play? 75 | 76 | We conduct experiments using an ensemble of Mamba modules without FFN and a single Mamba module with FFN. Experiments show that the FFN layer functions similarly to ensembling multiple Mamba modules while consuming much less VRAM. You may turn off the FFN layer by supplying `--r_ff 0`. 77 | 78 | ### Channel independence vs Channel mixing 79 | 80 | We find that channel independence may be more effective for series having low variate correlation, while channel mixing may be more effective for the opposite case. This finding is consistent with the paper [TFB](https://arxiv.org/abs/2403.20150).
However, the Mamba module is ineffective in channel-independent mode since the sequence length is only one, and the Mamba module cannot capture any long-term dependencies. You may use the channel independence mode by supplying `--channel_independence`. 81 | 82 | ### Reversible instance normalization 83 | 84 | [RevIN](https://github.com/ts-kim/RevIN) focuses on solving the distribution shift problem. We find this trick works well for most datasets. Learnable affine transformation may boost the performance further for some datasets. However, RevIN does not work well for (subsets of) the PEMS dataset. You may turn off RevIN using `--no_norm`. You may turn on learnable affine transformation using `--revin_affine`. 85 | 86 | ### Effectiveness of temporal features 87 | 88 | The temporal features are timestamps that indicate the month of year, day of week, hour of day, etc. We append these features to the series as additional variables and find that they are especially effective for the Traffic dataset. However, they may harm the performance on some datasets. We find that series with high seasonality benefit from such features. You may turn these features on using `--use_mark`. 89 | 90 | 91 | 92 | ## Acknowledgement 93 | 94 | We greatly appreciate the following GitHub repos for their valuable code and efforts. 95 | 96 | - RevIN (https://github.com/ts-kim/RevIN) 97 | - iTransformer (https://github.com/thuml/iTransformer) 98 | - Mamba (https://github.com/state-spaces/mamba) 99 | - VMamba (https://github.com/MzeroMiko/VMamba) 100 | 101 | I also want to thank Qianxiong Xu, Chenxi Liu, Ziyue Li and Cheng Long for their valuable advice.
102 | 103 | ## Contact 104 | 105 | If you have any questions or want to use the code, feel free to contact: [shaowen.zhou@ntu.edu.sg](mailto:shaowen.zhou@ntu.edu.sg) 106 | -------------------------------------------------------------------------------- /layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class PositionalEmbedding(nn.Module): 7 | def __init__(self, d_model, max_len=5000): 8 | super(PositionalEmbedding, self).__init__() 9 | # Compute the positional encodings once in log space. 10 | pe = torch.zeros(max_len, d_model).float() 11 | pe.require_grad = False 12 | 13 | position = torch.arange(0, max_len).float().unsqueeze(1) 14 | div_term = ( 15 | torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) 16 | ).exp() 17 | 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | 21 | pe = pe.unsqueeze(0) 22 | self.register_buffer("pe", pe) 23 | 24 | def forward(self, x): 25 | return self.pe[:, : x.size(1)] 26 | 27 | 28 | class TokenEmbedding(nn.Module): 29 | def __init__(self, c_in, d_model): 30 | super(TokenEmbedding, self).__init__() 31 | self.token_conv = nn.Conv1d( 32 | c_in, 33 | d_model, 34 | kernel_size=3, 35 | padding=1, 36 | padding_mode="circular", 37 | ) 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv1d): 40 | nn.init.kaiming_normal_( 41 | m.weight, mode="fan_in", nonlinearity="leaky_relu" 42 | ) 43 | 44 | def forward(self, x): 45 | return self.token_conv(x.transpose(1, 2)).transpose(1, 2) 46 | 47 | 48 | class FixedEmbedding(nn.Module): 49 | def __init__(self, c_in, d_model): 50 | super(FixedEmbedding, self).__init__() 51 | 52 | w = torch.zeros(c_in, d_model).float() 53 | w.require_grad = False 54 | 55 | position = torch.arange(0, c_in).float().unsqueeze(1) 56 | div_term = ( 57 | torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) 58 | ).exp() 59 
| 60 | w[:, 0::2] = torch.sin(position * div_term) 61 | w[:, 1::2] = torch.cos(position * div_term) 62 | 63 | self.emb = nn.Embedding(c_in, d_model) 64 | self.emb.weight = nn.Parameter(w, requires_grad=False) 65 | 66 | def forward(self, x): 67 | return self.emb(x).detach() 68 | 69 | 70 | class TemporalEmbedding(nn.Module): 71 | def __init__(self, d_model, embed_type="fixed", freq="h"): 72 | super(TemporalEmbedding, self).__init__() 73 | 74 | minute_size = 4 75 | hour_size = 24 76 | weekday_size = 7 77 | day_size = 32 78 | month_size = 13 79 | 80 | Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding 81 | if freq == "t": 82 | self.minute_embed = Embed(minute_size, d_model) 83 | self.hour_embed = Embed(hour_size, d_model) 84 | self.weekday_embed = Embed(weekday_size, d_model) 85 | self.day_embed = Embed(day_size, d_model) 86 | self.month_embed = Embed(month_size, d_model) 87 | 88 | def forward(self, x): 89 | x = x.long() 90 | minute_x = ( 91 | self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 92 | ) 93 | hour_x = self.hour_embed(x[:, :, 3]) 94 | weekday_x = self.weekday_embed(x[:, :, 2]) 95 | day_x = self.day_embed(x[:, :, 1]) 96 | month_x = self.month_embed(x[:, :, 0]) 97 | 98 | return hour_x + weekday_x + day_x + month_x + minute_x 99 | 100 | 101 | class TimeFeatureEmbedding(nn.Module): 102 | def __init__(self, d_model, embed_type="timeF", freq="h"): 103 | super(TimeFeatureEmbedding, self).__init__() 104 | 105 | freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} 106 | d_inp = freq_map[freq] 107 | self.embed = nn.Linear(d_inp, d_model, bias=False) 108 | 109 | def forward(self, x): 110 | return self.embed(x) 111 | 112 | 113 | class DataEmbedding(nn.Module): 114 | def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): 115 | super(DataEmbedding, self).__init__() 116 | 117 | self.value_embedding = TokenEmbedding(c_in, d_model) 118 | self.position_embedding = PositionalEmbedding(d_model) 
119 | self.temporal_embedding = ( 120 | TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 121 | if embed_type != "timeF" 122 | else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 123 | ) 124 | self.dropout = nn.Dropout(p=dropout) 125 | 126 | def forward(self, x, x_mark): 127 | # x: B L W 128 | if x_mark is None: 129 | x = self.value_embedding(x) + self.position_embedding(x) 130 | else: 131 | x = ( 132 | self.value_embedding(x) 133 | + self.temporal_embedding(x_mark) 134 | + self.position_embedding(x) 135 | ) 136 | return self.dropout(x) 137 | 138 | 139 | class SeriesEmbedding(nn.Module): 140 | def __init__(self, c_in, d_model, dropout=0.1): 141 | super().__init__() 142 | self.value_embedding = nn.Linear(c_in, d_model) 143 | self.dropout = nn.Dropout(p=dropout) 144 | 145 | def forward(self, x, x_mark=None): 146 | x = x.transpose(1, 2) 147 | # x: [Batch Variate Time] 148 | 149 | # the potential to take covariates (e.g. timestamps) as tokens 150 | if x_mark is not None: 151 | x = torch.cat([x, x_mark.transpose(1, 2)], 1) 152 | 153 | x = self.value_embedding(x) 154 | # x: [Batch Variate d_model] 155 | 156 | return self.dropout(x) 157 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: run.py weather", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/run.py", 12 | "args": [ 13 | "--is_training", 14 | "1", 15 | "--model_id", 16 | "DEBUG", 17 | "--model", 18 | "Mamba", 19 | "--data", 20 | "custom", 21 | "--root_path", 22 | "./dataset/weather", 23 | "--data_path", 24 | "weather.csv", 25 | "--enc_in", 26 | "21", 27 | "--dec_in", 28 | "21", 29 | "--c_out", 30 | "21", 31 | "--pred_len", 32 | "720", 33 | "--train_epochs", 34 | "10", 35 | "--warmup_epochs", 36 | "0", 37 | "--batch_size", 38 | "16", 39 | ], 40 | "cwd": "${workspaceFolder}", 41 | "console": "integratedTerminal", 42 | "justMyCode": true 43 | }, 44 | { 45 | "name": "Python: run.py etth1", 46 | "type": "debugpy", 47 | "request": "launch", 48 | "program": "${workspaceFolder}/run.py", 49 | "args": [ 50 | "--is_training", 51 | "1", 52 | "--model_id", 53 | "DEBUG", 54 | "--model", 55 | "TimesMamba", 56 | "--data", 57 | "custom", 58 | "--root_path", 59 | "./dataset/ETT-small/", 60 | "--data_path", 61 | "ETTh1.csv", 62 | "--enc_in", 63 | "7", 64 | "--dec_in", 65 | "7", 66 | "--c_out", 67 | "7", 68 | "--pred_len", 69 | "96", 70 | "--d_model", 71 | "128", 72 | "--train_epochs", 73 | "10", 74 | "--batch_size", 75 | "32", 76 | "--learning_rate", 77 | "1e-3", 78 | ], 79 | "cwd": "${workspaceFolder}", 80 | "console": "integratedTerminal", 81 | "justMyCode": true 82 | }, 83 | { 84 | "name": "Python: run.py electricity", 85 | "type": "debugpy", 86 | "request": "launch", 87 | "program": "${workspaceFolder}/run.py", 88 | "args": [ 89 | "--is_training", 90 | "1", 91 | "--model_id", 92 | "DEBUG", 93 | "--model", 94 | "TimesMamba", 95 | "--data", 96 | "custom", 97 | "--root_path", 98 | "./dataset/electricity", 99 | "--data_path", 100 | "electricity.csv", 101 | "--enc_in", 102 | "321", 103 | "--dec_in", 104 | "321", 105 | 
"--c_out", 106 | "321", 107 | "--pred_len", 108 | "720", 109 | "--d_model", 110 | "256", 111 | "--batch_size", 112 | "32", 113 | "--e_layers", 114 | "3", 115 | "--learning_rate", 116 | "1e-3", 117 | ], 118 | "cwd": "${workspaceFolder}", 119 | "console": "integratedTerminal", 120 | "justMyCode": false 121 | }, 122 | { 123 | "name": "Python: run.py traffic", 124 | "type": "debugpy", 125 | "request": "launch", 126 | "program": "${workspaceFolder}/run.py", 127 | "args": [ 128 | "--is_training", 129 | "1", 130 | "--model_id", 131 | "DEBUG", 132 | "--model", 133 | "TimesMamba", 134 | "--data", 135 | "custom", 136 | "--root_path", 137 | "./dataset/traffic", 138 | "--data_path", 139 | "traffic.csv", 140 | "--enc_in", 141 | "862", 142 | "--dec_in", 143 | "862", 144 | "--c_out", 145 | "862", 146 | "--pred_len", 147 | "720", 148 | "--d_model", 149 | "512", 150 | "--batch_size", 151 | "32", 152 | "--e_layers", 153 | "3", 154 | "--learning_rate", 155 | "1e-3", 156 | ], 157 | "cwd": "${workspaceFolder}", 158 | "console": "integratedTerminal", 159 | "justMyCode": false 160 | }, 161 | { 162 | "name": "Python: run.py PEMS08", 163 | "type": "debugpy", 164 | "request": "launch", 165 | "program": "${workspaceFolder}/run.py", 166 | "args": [ 167 | "--is_training", 168 | "1", 169 | "--model_id", 170 | "DEBUG", 171 | "--model", 172 | "Mamba", 173 | "--data", 174 | "PEMS", 175 | "--root_path", 176 | "./dataset/PEMS", 177 | "--data_path", 178 | "PEMS08.npz", 179 | "--seq_len", 180 | "96", 181 | "--pred_len", 182 | "12", 183 | "--enc_in", 184 | "170", 185 | "--dec_in", 186 | "170", 187 | "--c_out", 188 | "170", 189 | "--e_layers", 190 | "4", 191 | "--batch_size", 192 | "16", 193 | "--learning_rate", 194 | "1e-4", 195 | ], 196 | "cwd": "${workspaceFolder}", 197 | "console": "integratedTerminal", 198 | "justMyCode": true 199 | } 200 | ] 201 | } -------------------------------------------------------------------------------- /run.py: 
"""Command-line entry point for training / evaluating forecasting models.

Parses the experiment configuration, seeds the RNGs, and drives
``Exp_Long_Term_Forecast`` through train / test / (optional) predict.
"""
import argparse
import torch
import random
import numpy as np


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so ``--use_gpu False`` could never disable the
    GPU.  This parser interprets the text instead.

    Args:
        value: The raw argument (a string from the command line, or an actual
            ``bool`` when called programmatically).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "on", "1"}:
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in {"false", "f", "no", "n", "off", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Return the argument parser describing every experiment option.

    Returns:
        argparse.ArgumentParser: Parser whose option names, defaults and help
        texts match the original script; only the two boolean-valued options
        (``--use_gpu`` and ``--efficient_training``) parse their value
        differently (see :func:`str2bool`).
    """
    parser = argparse.ArgumentParser(description="Mamba")

    # basic config
    # NOTE: these options are required, so argparse never falls back to the
    # listed defaults; they are kept only to preserve the original declaration.
    parser.add_argument(
        "--is_training", type=int, required=True, default=1, help="status"
    )
    parser.add_argument(
        "--model_id", type=str, required=True, default="test", help="model id"
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        default="TimesMamba",
        help="model name, options: [TimesMamba]",
    )

    # data loader
    parser.add_argument(
        "--data", type=str, required=True, default="custom", help="dataset type"
    )
    parser.add_argument(
        "--root_path",
        type=str,
        default="./data/electricity/",
        help="root path of the data file",
    )
    parser.add_argument(
        "--data_path", type=str, default="electricity.csv", help="data csv file"
    )
    parser.add_argument(
        "--features",
        type=str,
        default="M",
        help="forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate",
    )
    parser.add_argument(
        "--target", type=str, default="OT", help="target feature in S or MS task"
    )
    parser.add_argument(
        "--freq",
        type=str,
        default="h",
        help="freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h",
    )
    parser.add_argument(
        "--checkpoints",
        type=str,
        default="./checkpoints/",
        help="location of model checkpoints",
    )

    # forecasting task
    parser.add_argument("--seq_len", type=int, default=96, help="input sequence length")
    parser.add_argument(
        "--label_len", type=int, default=48, help="start token length"
    )  # no longer needed in inverted Transformers
    parser.add_argument(
        "--pred_len", type=int, default=96, help="prediction sequence length"
    )

    # model define
    parser.add_argument(
        "--no_norm", dest="use_norm", action="store_false", help="use norm and denorm"
    )
    parser.add_argument("--revin_affine", action="store_true", help="RevIN affine")
    parser.add_argument("--use_mark", action="store_true", help="use timestamp feature")
    parser.add_argument(
        "--channel_independence",
        action="store_true",
        help="whether to use channel_independence mechanism",
    )
    parser.add_argument("--enc_in", type=int, default=7, help="encoder input size")
    parser.add_argument(
        "--enc_m_in", type=int, default=4, help="encoder data marker input size"
    )
    parser.add_argument("--dec_in", type=int, default=7, help="decoder input size")
    parser.add_argument(
        "--c_out", type=int, default=7, help="output size"
    )  # applicable on arbitrary number of variates in inverted Transformers
    parser.add_argument("--d_model", type=int, default=512, help="dimension of model")
    parser.add_argument("--ssm_expand", type=int, default=1, help="expand factor")
    parser.add_argument("--n_heads", type=int, default=8, help="num of heads")
    parser.add_argument("--e_layers", type=int, default=2, help="num of encoder layers")
    parser.add_argument("--d_layers", type=int, default=1, help="num of decoder layers")
    parser.add_argument(
        "--r_ff", type=int, default=4, help="ratio ffn hidden dimension / d_model"
    )
    parser.add_argument(
        "--moving_avg", type=int, default=25, help="window size of moving average"
    )
    parser.add_argument("--factor", type=int, default=1, help="attn factor")
    parser.add_argument(
        "--nodistil",
        dest="distil",
        action="store_false",
        help="whether to use distilling in encoder, using this argument means not using distilling",
        default=True,
    )
    parser.add_argument("--dropout", type=float, default=0.1, help="dropout")
    parser.add_argument(
        "--embed",
        type=str,
        default="timeF",
        help="time features encoding, options:[timeF, fixed, learned]",
    )
    parser.add_argument("--activation", type=str, default="gelu", help="activation")
    parser.add_argument(
        "--output_attention",
        action="store_true",
        help="whether to output attention in ecoder",
    )
    parser.add_argument(
        "--do_predict",
        action="store_true",
        help="whether to predict unseen future data",
    )

    # optimization
    parser.add_argument(
        "--num_workers", type=int, default=10, help="data loader num workers"
    )
    parser.add_argument("--itr", type=int, default=1, help="experiments times")
    parser.add_argument("--train_epochs", type=int, default=10, help="train epochs")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size of train input data"
    )
    parser.add_argument("--warmup_epochs", type=int, default=0, help="warmup epochs")
    parser.add_argument(
        "--patience", type=int, default=-1, help="early stopping patience"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=0.0001, help="optimizer learning rate"
    )
    parser.add_argument("--des", type=str, default="test", help="exp description")
    parser.add_argument("--loss", type=str, default="MSE", help="loss function")
    parser.add_argument(
        "--lradj", type=str, default="type1", help="adjust learning rate"
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        help="use automatic mixed precision training",
        default=False,
    )

    # GPU
    # str2bool (not bool) so that "--use_gpu False" really turns the GPU off.
    parser.add_argument("--use_gpu", type=str2bool, default=True, help="use gpu")
    parser.add_argument(
        "--device", type=str, default="0", help="device ids of multile gpus"
    )

    # experiment
    parser.add_argument(
        "--exp_name",
        type=str,
        required=False,
        default="MTSF",
        help="experiemnt name, options:[MTSF, partial_train]",
    )
    parser.add_argument("--inverse", action="store_true", help="inverse output data")
    parser.add_argument(
        "--class_strategy",
        type=str,
        default="projection",
        help="projection/average/cls_token",
    )
    parser.add_argument(
        "--efficient_training",
        type=str2bool,
        default=False,
        help="whether to use efficient_training (exp_name should be partial train)",
    )
    parser.add_argument(
        "--partial_start_index",
        type=int,
        default=0,
        help="the start index of variates for partial training, "
        "you can select [partial_start_index, min(enc_in + partial_start_index, N)]",
    )
    return parser


def build_setting(args, ii):
    """Return the run identifier used to name checkpoints and result folders.

    Args:
        args: Parsed arguments (anything exposing the attributes used below).
        ii: Index of the current repetition.

    Returns:
        str: Underscore-separated description of the run configuration.
    """
    return (
        f"{args.model_id}_{args.model}_{args.data}_{ii}"
        f"_ft{args.features}_sl{args.seq_len}_ll{args.label_len}_pl{args.pred_len}"
        f"_dm{args.d_model}_nh{args.n_heads}_el{args.e_layers}_dl{args.d_layers}"
        f"_df{args.r_ff}_fc{args.factor}_eb{args.embed}_dt{args.distil}"
        f"_{args.des}_{args.class_strategy}"
    )


def main(argv=None):
    """Seed the RNGs, parse the arguments and run the requested experiment.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    # Imported here rather than at module scope so the parser helpers above can
    # be imported without the model stack; it still happens before seeding, as
    # in the original flow.
    from experiments.exp_long_term_forecasting import Exp_Long_Term_Forecast

    fix_seed = 2024
    random.seed(fix_seed)
    torch.manual_seed(fix_seed)
    np.random.seed(fix_seed)

    args = build_parser().parse_args(argv)
    # Only use the GPU when it was requested *and* CUDA is actually available.
    args.use_gpu = bool(torch.cuda.is_available() and args.use_gpu)

    print("Args in experiment:")
    print(args)

    # MTSF: multivariate time series forecasting
    Exp = Exp_Long_Term_Forecast

    if args.is_training:
        for ii in range(args.itr):
            # setting record of experiments
            setting = build_setting(args, ii)

            exp = Exp(args)  # set experiments
            print(f">>>>>>>start training: {setting}>>>>>>>>>>>>>>>>>>>>>>>>>>")
            exp.train(setting)

            print(f">>>>>>>testing: {setting}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
            exp.test(setting)

            if args.do_predict:
                print(f">>>>>>>predicting: {setting}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
                exp.predict(setting, True)

            torch.cuda.empty_cache()
    else:
        # Evaluation only: load the checkpoint of repetition 0 and test it.
        setting = build_setting(args, 0)

        exp = Exp(args)  # set experiments
        print(f">>>>>>>testing : {setting}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        exp.test(setting, test=1)
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | SOFTWARE. 31 | 32 | --------------------------------------------------------- 33 | 34 | --------------------------------------------------------- 35 | 36 | @thuml/iTransformer - MIT 37 | https://github.com/thuml/iTransformer 38 | 39 | MIT License 40 | 41 | Copyright (c) 2022 THUML @ Tsinghua University 42 | 43 | Permission is hereby granted, free of charge, to any person obtaining a copy 44 | of this software and associated documentation files (the "Software"), to deal 45 | in the Software without restriction, including without limitation the rights 46 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 47 | copies of the Software, and to permit persons to whom the Software is 48 | furnished to do so, subject to the following conditions: 49 | 50 | The above copyright notice and this permission notice shall be included in all 51 | copies or substantial portions of the Software. 52 | 53 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 54 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 55 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 56 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 57 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 58 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 59 | SOFTWARE. 
60 | 61 | --------------------------------------------------------- 62 | 63 | --------------------------------------------------------- 64 | 65 | @state-spaces/mamba - Apache License V2.0 66 | https://github.com/state-spaces/mamba 67 | 68 | Apache License 69 | Version 2.0, January 2004 70 | http://www.apache.org/licenses/ 71 | 72 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 73 | 74 | 1. Definitions. 75 | 76 | "License" shall mean the terms and conditions for use, reproduction, 77 | and distribution as defined by Sections 1 through 9 of this document. 78 | 79 | "Licensor" shall mean the copyright owner or entity authorized by 80 | the copyright owner that is granting the License. 81 | 82 | "Legal Entity" shall mean the union of the acting entity and all 83 | other entities that control, are controlled by, or are under common 84 | control with that entity. For the purposes of this definition, 85 | "control" means (i) the power, direct or indirect, to cause the 86 | direction or management of such entity, whether by contract or 87 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 88 | outstanding shares, or (iii) beneficial ownership of such entity. 89 | 90 | "You" (or "Your") shall mean an individual or Legal Entity 91 | exercising permissions granted by this License. 92 | 93 | "Source" form shall mean the preferred form for making modifications, 94 | including but not limited to software source code, documentation 95 | source, and configuration files. 96 | 97 | "Object" form shall mean any form resulting from mechanical 98 | transformation or translation of a Source form, including but 99 | not limited to compiled object code, generated documentation, 100 | and conversions to other media types. 
101 | 102 | "Work" shall mean the work of authorship, whether in Source or 103 | Object form, made available under the License, as indicated by a 104 | copyright notice that is included in or attached to the work 105 | (an example is provided in the Appendix below). 106 | 107 | "Derivative Works" shall mean any work, whether in Source or Object 108 | form, that is based on (or derived from) the Work and for which the 109 | editorial revisions, annotations, elaborations, or other modifications 110 | represent, as a whole, an original work of authorship. For the purposes 111 | of this License, Derivative Works shall not include works that remain 112 | separable from, or merely link (or bind by name) to the interfaces of, 113 | the Work and Derivative Works thereof. 114 | 115 | "Contribution" shall mean any work of authorship, including 116 | the original version of the Work and any modifications or additions 117 | to that Work or Derivative Works thereof, that is intentionally 118 | submitted to Licensor for inclusion in the Work by the copyright owner 119 | or by an individual or Legal Entity authorized to submit on behalf of 120 | the copyright owner. For the purposes of this definition, "submitted" 121 | means any form of electronic, verbal, or written communication sent 122 | to the Licensor or its representatives, including but not limited to 123 | communication on electronic mailing lists, source code control systems, 124 | and issue tracking systems that are managed by, or on behalf of, the 125 | Licensor for the purpose of discussing and improving the Work, but 126 | excluding communication that is conspicuously marked or otherwise 127 | designated in writing by the copyright owner as "Not a Contribution." 128 | 129 | "Contributor" shall mean Licensor and any individual or Legal Entity 130 | on behalf of whom a Contribution has been received by Licensor and 131 | subsequently incorporated within the Work. 132 | 133 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 134 | this License, each Contributor hereby grants to You a perpetual, 135 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 136 | copyright license to reproduce, prepare Derivative Works of, 137 | publicly display, publicly perform, sublicense, and distribute the 138 | Work and such Derivative Works in Source or Object form. 139 | 140 | 3. Grant of Patent License. Subject to the terms and conditions of 141 | this License, each Contributor hereby grants to You a perpetual, 142 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 143 | (except as stated in this section) patent license to make, have made, 144 | use, offer to sell, sell, import, and otherwise transfer the Work, 145 | where such license applies only to those patent claims licensable 146 | by such Contributor that are necessarily infringed by their 147 | Contribution(s) alone or by combination of their Contribution(s) 148 | with the Work to which such Contribution(s) was submitted. If You 149 | institute patent litigation against any entity (including a 150 | cross-claim or counterclaim in a lawsuit) alleging that the Work 151 | or a Contribution incorporated within the Work constitutes direct 152 | or contributory patent infringement, then any patent licenses 153 | granted to You under this License for that Work shall terminate 154 | as of the date such litigation is filed. 155 | 156 | 4. Redistribution. 
You may reproduce and distribute copies of the 157 | Work or Derivative Works thereof in any medium, with or without 158 | modifications, and in Source or Object form, provided that You 159 | meet the following conditions: 160 | 161 | (a) You must give any other recipients of the Work or 162 | Derivative Works a copy of this License; and 163 | 164 | (b) You must cause any modified files to carry prominent notices 165 | stating that You changed the files; and 166 | 167 | (c) You must retain, in the Source form of any Derivative Works 168 | that You distribute, all copyright, patent, trademark, and 169 | attribution notices from the Source form of the Work, 170 | excluding those notices that do not pertain to any part of 171 | the Derivative Works; and 172 | 173 | (d) If the Work includes a "NOTICE" text file as part of its 174 | distribution, then any Derivative Works that You distribute must 175 | include a readable copy of the attribution notices contained 176 | within such NOTICE file, excluding those notices that do not 177 | pertain to any part of the Derivative Works, in at least one 178 | of the following places: within a NOTICE text file distributed 179 | as part of the Derivative Works; within the Source form or 180 | documentation, if provided along with the Derivative Works; or, 181 | within a display generated by the Derivative Works, if and 182 | wherever such third-party notices normally appear. The contents 183 | of the NOTICE file are for informational purposes only and 184 | do not modify the License. You may add Your own attribution 185 | notices within Derivative Works that You distribute, alongside 186 | or as an addendum to the NOTICE text from the Work, provided 187 | that such additional attribution notices cannot be construed 188 | as modifying the License. 
189 | 190 | You may add Your own copyright statement to Your modifications and 191 | may provide additional or different license terms and conditions 192 | for use, reproduction, or distribution of Your modifications, or 193 | for any such Derivative Works as a whole, provided Your use, 194 | reproduction, and distribution of the Work otherwise complies with 195 | the conditions stated in this License. 196 | 197 | 5. Submission of Contributions. Unless You explicitly state otherwise, 198 | any Contribution intentionally submitted for inclusion in the Work 199 | by You to the Licensor shall be under the terms and conditions of 200 | this License, without any additional terms or conditions. 201 | Notwithstanding the above, nothing herein shall supersede or modify 202 | the terms of any separate license agreement you may have executed 203 | with Licensor regarding such Contributions. 204 | 205 | 6. Trademarks. This License does not grant permission to use the trade 206 | names, trademarks, service marks, or product names of the Licensor, 207 | except as required for reasonable and customary use in describing the 208 | origin of the Work and reproducing the content of the NOTICE file. 209 | 210 | 7. Disclaimer of Warranty. Unless required by applicable law or 211 | agreed to in writing, Licensor provides the Work (and each 212 | Contributor provides its Contributions) on an "AS IS" BASIS, 213 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 214 | implied, including, without limitation, any warranties or conditions 215 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 216 | PARTICULAR PURPOSE. You are solely responsible for determining the 217 | appropriateness of using or redistributing the Work and assume any 218 | risks associated with Your exercise of permissions under this License. 219 | 220 | 8. Limitation of Liability. 
In no event and under no legal theory, 221 | whether in tort (including negligence), contract, or otherwise, 222 | unless required by applicable law (such as deliberate and grossly 223 | negligent acts) or agreed to in writing, shall any Contributor be 224 | liable to You for damages, including any direct, indirect, special, 225 | incidental, or consequential damages of any character arising as a 226 | result of this License or out of the use or inability to use the 227 | Work (including but not limited to damages for loss of goodwill, 228 | work stoppage, computer failure or malfunction, or any and all 229 | other commercial damages or losses), even if such Contributor 230 | has been advised of the possibility of such damages. 231 | 232 | 9. Accepting Warranty or Additional Liability. While redistributing 233 | the Work or Derivative Works thereof, You may choose to offer, 234 | and charge a fee for, acceptance of support, warranty, indemnity, 235 | or other liability obligations and/or rights consistent with this 236 | License. However, in accepting such obligations, You may act only 237 | on Your own behalf and on Your sole responsibility, not on behalf 238 | of any other Contributor, and only if You agree to indemnify, 239 | defend, and hold each Contributor harmless for any liability 240 | incurred by, or claims asserted against, such Contributor by reason 241 | of your accepting any such warranty or additional liability. 242 | 243 | END OF TERMS AND CONDITIONS 244 | 245 | APPENDIX: How to apply the Apache License to your work. 246 | 247 | To apply the Apache License to your work, attach the following 248 | boilerplate notice, with the fields enclosed by brackets "[]" 249 | replaced with your own identifying information. (Don't include 250 | the brackets!) The text should be enclosed in the appropriate 251 | comment syntax for the file format. 
We also recommend that a 252 | file or class name and description of purpose be included on the 253 | same "printed page" as the copyright notice for easier 254 | identification within third-party archives. 255 | 256 | Copyright 2023 Tri Dao, Albert Gu 257 | 258 | Licensed under the Apache License, Version 2.0 (the "License"); 259 | you may not use this file except in compliance with the License. 260 | You may obtain a copy of the License at 261 | 262 | http://www.apache.org/licenses/LICENSE-2.0 263 | 264 | Unless required by applicable law or agreed to in writing, software 265 | distributed under the License is distributed on an "AS IS" BASIS, 266 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 267 | See the License for the specific language governing permissions and 268 | limitations under the License. 269 | -------------------------------------------------------------------------------- /experiments/exp_long_term_forecasting.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_factory import data_provider 2 | from experiments.exp_basic import Exp_Basic 3 | from utils.tools import EarlyStopping, get_lr_scheduler, visual 4 | from utils.metrics import metric 5 | import torch 6 | import torch.nn as nn 7 | from torch import optim 8 | import os 9 | import time 10 | import numpy as np 11 | import loggingutil 12 | 13 | logger = loggingutil.get_logger(__name__) 14 | 15 | 16 | class Exp_Long_Term_Forecast(Exp_Basic): 17 | def __init__(self, args): 18 | super(Exp_Long_Term_Forecast, self).__init__(args) 19 | 20 | def _build_model(self): 21 | model = self.model_dict[self.args.model].Model(self.args).float() 22 | return model 23 | 24 | def _get_data(self, flag): 25 | data_set, data_loader = data_provider(self.args, flag) 26 | return data_set, data_loader 27 | 28 | def _select_optimizer(self): 29 | model_optim = optim.AdamW( 30 | self.model.parameters(), 31 | lr=self.args.learning_rate, 32 
| betas=(0.9, 0.95), 33 | weight_decay=0.05, 34 | ) 35 | return model_optim 36 | 37 | def _select_criterion(self): 38 | return nn.MSELoss() 39 | 40 | def vali(self, vali_data, vali_loader, criterion): 41 | self.model.eval() 42 | 43 | total_loss = [] 44 | 45 | with torch.no_grad(): 46 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate( 47 | vali_loader 48 | ): 49 | batch_x = batch_x.float().to(self.device) 50 | batch_y = batch_y.float() 51 | 52 | if "PEMS" in self.args.data or "Solar" in self.args.data: 53 | batch_x_mark = None 54 | batch_y_mark = None 55 | else: 56 | batch_x_mark = batch_x_mark.float().to(self.device) 57 | batch_y_mark = batch_y_mark.float().to(self.device) 58 | 59 | # decoder input 60 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len :, :]).float() 61 | dec_inp = ( 62 | torch.cat([batch_y[:, : self.args.label_len, :], dec_inp], dim=1) 63 | .float() 64 | .to(self.device) 65 | ) 66 | 67 | # encoder - decoder 68 | if self.args.use_amp: 69 | with torch.cuda.amp.autocast(): 70 | if self.args.output_attention: 71 | outputs = self.model( 72 | batch_x, batch_x_mark, dec_inp, batch_y_mark 73 | )[0] 74 | else: 75 | outputs = self.model( 76 | batch_x, batch_x_mark, dec_inp, batch_y_mark 77 | ) 78 | else: 79 | if self.args.output_attention: 80 | outputs = self.model( 81 | batch_x, batch_x_mark, dec_inp, batch_y_mark 82 | )[0] 83 | else: 84 | outputs = self.model( 85 | batch_x, batch_x_mark, dec_inp, batch_y_mark 86 | ) 87 | 88 | f_dim = -1 if self.args.features == "MS" else 0 89 | outputs = outputs[:, -self.args.pred_len :, f_dim:] 90 | batch_y = batch_y[:, -self.args.pred_len :, f_dim:] 91 | 92 | pred = outputs.detach().cpu() 93 | true = batch_y.detach().cpu() 94 | 95 | loss = criterion(pred, true) * len(batch_x) 96 | 97 | total_loss.append(loss.item()) 98 | 99 | total_samples = len(vali_loader.dataset) 100 | if vali_loader.drop_last: 101 | total_samples -= len(vali_loader.dataset) % vali_loader.batch_size 102 | total_loss = 
np.sum(total_loss) / total_samples 103 | 104 | return total_loss 105 | 106 | def train(self, setting): 107 | _, train_loader = self._get_data(flag="train") 108 | vali_data, vali_loader = self._get_data(flag="val") 109 | test_data, test_loader = self._get_data(flag="test") 110 | 111 | path = os.path.join(self.args.checkpoints, setting) 112 | if not os.path.exists(path): 113 | os.makedirs(path) 114 | 115 | time_now = time.time() 116 | 117 | train_steps = len(train_loader) 118 | early_stopping = EarlyStopping(patience=self.args.patience, verbose=True) 119 | 120 | model_optim = self._select_optimizer() 121 | criterion = self._select_criterion() 122 | 123 | lr_scheduler = get_lr_scheduler( 124 | model_optim, 125 | self.args.train_epochs, 126 | self.args.warmup_epochs, 127 | self.args.lradj, 128 | ) 129 | 130 | if self.args.use_amp: 131 | scaler = torch.cuda.amp.GradScaler() 132 | 133 | for epoch in range(self.args.train_epochs): 134 | iter_count = 0 135 | train_loss = [] 136 | 137 | print(f"Learning rate {model_optim.param_groups[0]['lr']}") 138 | 139 | self.model.train() 140 | epoch_time = time.time() 141 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate( 142 | train_loader 143 | ): 144 | iter_count += 1 145 | model_optim.zero_grad() 146 | batch_x = batch_x.float().to(self.device) 147 | 148 | batch_y = batch_y.float().to(self.device) 149 | if "PEMS" in self.args.data or "Solar" in self.args.data: 150 | batch_x_mark = None 151 | batch_y_mark = None 152 | else: 153 | batch_x_mark = batch_x_mark.float().to(self.device) 154 | batch_y_mark = batch_y_mark.float().to(self.device) 155 | 156 | # decoder input 157 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len :, :]).float() 158 | dec_inp = ( 159 | torch.cat([batch_y[:, : self.args.label_len, :], dec_inp], dim=1) 160 | .float() 161 | .to(self.device) 162 | ) 163 | 164 | # encoder - decoder 165 | if self.args.use_amp: 166 | with torch.cuda.amp.autocast(): 167 | if self.args.output_attention: 168 | 
outputs = self.model( 169 | batch_x, batch_x_mark, dec_inp, batch_y_mark 170 | )[0] 171 | else: 172 | outputs = self.model( 173 | batch_x, batch_x_mark, dec_inp, batch_y_mark 174 | ) 175 | 176 | f_dim = -1 if self.args.features == "MS" else 0 177 | outputs = outputs[:, -self.args.pred_len :, f_dim:] 178 | batch_y = batch_y[:, -self.args.pred_len :, f_dim:] 179 | loss = criterion(outputs, batch_y) 180 | train_loss.append(loss.item()) 181 | else: 182 | if self.args.output_attention: 183 | outputs = self.model( 184 | batch_x, batch_x_mark, dec_inp, batch_y_mark 185 | )[0] 186 | else: 187 | outputs = self.model( 188 | batch_x, batch_x_mark, dec_inp, batch_y_mark 189 | ) 190 | 191 | f_dim = -1 if self.args.features == "MS" else 0 192 | outputs = outputs[:, -self.args.pred_len :, f_dim:] 193 | batch_y = batch_y[:, -self.args.pred_len :, f_dim:] 194 | loss = criterion(outputs, batch_y) 195 | train_loss.append(loss.item()) 196 | 197 | if (i + 1) % 100 == 0: 198 | logger.info( 199 | "\titers: {0}, epoch: {1} | loss: {2:.7f}".format( 200 | i + 1, epoch + 1, loss.item() 201 | ) 202 | ) 203 | speed = (time.time() - time_now) / iter_count 204 | left_time = speed * ( 205 | (self.args.train_epochs - epoch) * train_steps - i 206 | ) 207 | logger.info( 208 | "\tspeed: {:.4f}s/iter; left time: {:.4f}s".format( 209 | speed, left_time 210 | ) 211 | ) 212 | iter_count = 0 213 | time_now = time.time() 214 | 215 | if self.args.use_amp: 216 | scaler.scale(loss).backward() 217 | scaler.step(model_optim) 218 | scaler.update() 219 | else: 220 | loss.backward() 221 | model_optim.step() 222 | 223 | logger.info( 224 | "Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time) 225 | ) 226 | train_loss = np.average(train_loss) 227 | vali_loss = self.vali(vali_data, vali_loader, criterion) 228 | test_loss = self.vali(test_data, test_loader, criterion) 229 | 230 | logger.info( 231 | "Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format( 232 | epoch + 
def test(self, setting, test=0):
    """Evaluate the model on the test split, plot samples and log MSE/MAE.

    Args:
        setting: Run identifier. It names the sub-directory read under
            ``./checkpoints/`` and the one written under ``./test_results/``.
        test: When truthy, model weights are first loaded from
            ``./checkpoints/<setting>/checkpoint.pth``.

    Returns:
        None. Metrics are logged and appended to
        ``result_long_term_forecast.txt``; every 20th batch is plotted to a
        PDF inside ``./test_results/<setting>/``.
    """
    test_data, test_loader = self._get_data(flag="test")
    if test:
        logger.info("loading model")
        self.model.load_state_dict(
            torch.load(
                os.path.join("./checkpoints/" + setting, "checkpoint.pth"),
                # Remap storages so a checkpoint written on another device
                # can still be restored on the current one.
                map_location=self.device,
            )
        )

    folder_path = "./test_results/" + setting + "/"
    os.makedirs(folder_path, exist_ok=True)

    drop_marks = "PEMS" in self.args.data or "Solar" in self.args.data
    undo_scaling = test_data.scale and self.args.inverse
    pred_len = self.args.pred_len
    f_dim = -1 if self.args.features == "MS" else 0

    def unscale(arr):
        # The scaler works on 2-D (rows, features) data, so flatten every
        # leading axis. The previous ``squeeze(0)`` only produced a 2-D
        # array when the batch size was exactly 1.
        shape = arr.shape
        flat = arr.reshape(-1, shape[-1])
        return test_data.inverse_transform(flat).reshape(shape)

    def run_model(enc_x, enc_mark, dec_x, dec_mark):
        # With output_attention the model returns a sequence whose first
        # item is the forecast.
        result = self.model(enc_x, enc_mark, dec_x, dec_mark)
        return result[0] if self.args.output_attention else result

    preds = []
    trues = []

    self.model.eval()
    with torch.no_grad():
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(
            test_loader
        ):
            batch_x = batch_x.float().to(self.device)
            batch_y = batch_y.float().to(self.device)

            if drop_marks:
                batch_x_mark = None
                batch_y_mark = None
            else:
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

            # decoder input: known label window followed by zeros for the
            # horizon that has to be predicted
            dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
            dec_inp = (
                torch.cat([batch_y[:, : self.args.label_len, :], dec_inp], dim=1)
                .float()
                .to(self.device)
            )

            # encoder - decoder
            if self.args.use_amp:
                with torch.cuda.amp.autocast():
                    outputs = run_model(
                        batch_x, batch_x_mark, dec_inp, batch_y_mark
                    )
            else:
                outputs = run_model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            pred = outputs[:, -pred_len:, f_dim:].detach().cpu().numpy()
            true = batch_y[:, -pred_len:, f_dim:].detach().cpu().numpy()
            if undo_scaling:
                pred = unscale(pred)
                true = unscale(true)

            preds.append(pred)
            trues.append(true)

            if i % 20 == 0:
                history = batch_x.detach().cpu().numpy()
                if undo_scaling:
                    history = unscale(history)
                # Plot the last variate of the first sample in the batch.
                gt = np.concatenate((history[0, :, -1], true[0, :, -1]), axis=0)
                forecast = np.concatenate(
                    (history[0, :, -1], pred[0, :, -1]), axis=0
                )
                visual(gt, forecast, os.path.join(folder_path, str(i) + ".pdf"))

    preds = np.concatenate(preds)
    trues = np.concatenate(trues)
    logger.info(f"test shape: {preds.shape} {trues.shape}")

    mae, mse, *_ = metric(preds, trues)
    logger.info("mse:{}, mae:{}".format(mse, mae))
    with open("result_long_term_forecast.txt", "a") as f:
        f.write(setting + " \n")
        f.write("mse:{}, mae:{}".format(mse, mae))
        f.write("\n")
        f.write("\n")

    # NOTE: dumping pred.npy / true.npy / metrics.npy under ./results/ was
    # commented out in the original code and is still not performed here.

    return
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Args:
        x: Input tensor; the first axis is treated as the sample axis.
        drop_prob: Probability of dropping a sample. ``None`` and ``0`` both
            disable the operation.
        training: The operation is only applied when this is true.

    Returns:
        ``x`` itself when disabled; otherwise a tensor of the same shape in
        which each sample is either zeroed or scaled by ``1 / (1 - drop_prob)``.
    """
    # ``not drop_prob`` also covers None, which is DropPath's default.
    if not drop_prob or not training:
        return x
    if drop_prob >= 1.0:
        # keep_prob would be 0 and the rescaling below would yield NaN.
        return torch.zeros_like(x)
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining axes.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # values in [keep_prob, 1 + keep_prob) -> {0, 1}
    return x.div(keep_prob) * mask


class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` that follows ``self.training``."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
@staticmethod
def dt_init(
    dt_rank,
    d_inner,
    dt_scale=1.0,
    dt_init="random",
    dt_min=0.001,
    dt_max=0.1,
    dt_init_floor=1e-4,
    **factory_kwargs,
):
    """Build the ``dt_rank -> d_inner`` linear layer used for the step size.

    The weight is filled either with a constant or uniformly at random,
    scaled by ``dt_rank ** -0.5 * dt_scale``. The bias is set to the inverse
    softplus of values drawn log-uniformly between ``dt_min`` and ``dt_max``
    (clamped from below by ``dt_init_floor``).

    Raises:
        NotImplementedError: if ``dt_init`` is neither "constant" nor "random".
    """
    proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

    # Weight scale chosen to preserve variance at initialization.
    std = dt_rank**-0.5 * dt_scale
    if dt_init == "random":
        nn.init.uniform_(proj.weight, -std, std)
    elif dt_init == "constant":
        nn.init.constant_(proj.weight, std)
    else:
        raise NotImplementedError

    # Log-uniform sample of the target step sizes.
    log_lo = math.log(dt_min)
    log_hi = math.log(dt_max)
    u = torch.rand(d_inner, **factory_kwargs)
    dt = torch.exp(u * (log_hi - log_lo) + log_lo).clamp(min=dt_init_floor)

    # Inverse of softplus, so that softplus(bias) reproduces ``dt``:
    # https://github.com/pytorch/pytorch/issues/72759
    bias_init = dt + torch.log(-torch.expm1(-dt))
    with torch.no_grad():
        proj.bias.copy_(bias_init)
    # Flag for weight-init code so it can skip this bias.
    proj.bias._no_reinit = True

    return proj
def forward_core(self, x: torch.Tensor):
    """Run the multi-directional selective scan over a 2-D feature map.

    x: (B, DI, H, W)

    B: batch size
    DI: d_inner = d_model * expand
    H: height, aka. the number of variates
    W: width, aka. sequence length

    With ``n_scan_directions == 2`` the map is scanned row-major, forwards
    and backwards. With ``n_scan_directions == 4`` a column-major scan and
    its reverse are added. Every scan result is mapped back to row-major
    order and the directions are summed.

    Returns:
        Tensor of shape (B, DI, H, W).

    Raises:
        ValueError: if ``n_scan_directions`` is neither 2 nor 4.
    """
    B, DI, H, W = x.shape
    K = self.n_scan_directions
    L = H * W

    if K == 2:
        # row-major only: (b, 1, d_inner, l)
        forward_scans = x.reshape(B, 1, DI, L)
    elif K == 4:
        # Row-major and column-major. The transpose has to be taken on the
        # 4-D map; the previous code flattened ``x`` to 3-D first and then
        # called transpose(2, 3), which always failed.
        forward_scans = torch.stack(
            [
                x.reshape(B, DI, L),
                x.transpose(2, 3).reshape(B, DI, L),
            ],
            dim=1,
        )  # (b, 2, d_inner, l)
    else:
        raise ValueError(
            f"n_scan_directions must be 2 or 4, got {K!r}"
        )

    # Append the reversed version of every forward scan.
    xs = torch.cat(
        [forward_scans, torch.flip(forward_scans, dims=[-1])], dim=1
    )  # (b, k, d_inner, l)

    # self.x_proj_weight: (k, dt_rank + d_state*2, d_inner)
    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
    dts, Bs, Cs = torch.split(
        x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2
    )
    # dts: (b, k, dt_rank, l); Bs, Cs: (b, k, d_state, l)
    # self.dt_projs_weight: (k, d_inner, dt_rank)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

    # The scan is fed float32 tensors with directions folded into channels.
    xs = xs.float().view(B, -1, L)  # (b, k*d_inner, l)
    dts = dts.float().contiguous().view(B, -1, L)  # (b, k*d_inner, l)
    As = -torch.exp(self.A_logs.float())  # (k*d_inner, d_state)
    Bs = Bs.float()  # (b, k, d_state, l)
    Cs = Cs.float()  # (b, k, d_state, l)
    Ds = self.Ds.float()  # (k*d_inner)
    dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k*d_inner)

    y = self.selective_scan(
        xs,
        dts,
        As,
        Bs,
        Cs,
        Ds,
        delta_bias=dt_projs_bias,
        delta_softplus=True,
    ).view(B, K, -1, L)

    # Undo the reversal of the second half of the directions and merge each
    # reversed scan with its forward counterpart. Done out-of-place so the
    # scan output itself is never modified.
    half = K // 2
    merged = y[:, :half] + torch.flip(y[:, half:], dims=[-1])  # (b, half, d, l)

    out = merged[:, 0]
    if K == 4:
        # Column-major results are laid out as (W, H); bring them back to
        # row-major (H, W) order before summing.
        col_major = (
            merged[:, 1].reshape(B, -1, W, H).transpose(2, 3).reshape(B, -1, L)
        )
        out = out + col_major

    return out.reshape(B, -1, H, W)  # b d_inner h w
class gMlp(nn.Module):
    """Gated MLP: ``fc2(value * act(gate))`` followed by dropout.

    ``fc1`` projects the input to ``2 * hidden_features`` channels, which are
    split in half along the last axis into a value part and a gate part.

    Args:
        in_features: Size of the last axis of the input.
        hidden_features: Width of each half after ``fc1``; defaults to
            ``in_features``.
        out_features: Size of the last axis of the output; defaults to
            ``in_features``.
        act_layer: Module class instantiated for the gate activation.
        drop: Dropout probability applied to the output.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, 2 * hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        # nn.Linear acts on the last axis, so that is where the doubled
        # channels live. The previous code consulted ``self.channel_first``,
        # an attribute this class never defines, and raised AttributeError.
        x, z = x.chunk(2, dim=-1)
        x = self.fc2(x * self.act(z))
        x = self.drop(x)
        return x
class MambaformerLayer(nn.Module):
    """Residual block made of an optional SSM branch and an optional MLP branch.

    The SSM branch exists when ``ssm_expand > 0`` and the MLP branch when
    ``mlp_ratio > 0``. Each active branch is applied as
    ``x + drop_path(branch(norm(x)))``.
    """

    def __init__(
        self,
        hidden_dim: int,
        drop_path_rate: float = 0.1,
        norm_layer: nn.Module = nn.LayerNorm,
        # ==== ssm
        ssm_d_state: int = 16,
        ssm_expand: int = 1,
        ssm_conv: int = 3,
        ssm_directions: int = 2,
        ssm_drop_rate: float = 0.1,
        # ==== mlp
        mlp_ratio: float = 4.0,
        mlp_act_layer: nn.Module = nn.GELU,
        mlp_drop_rate: float = 0.1,
        gmlp: bool = False,
        **kwargs,
    ):
        super().__init__()

        # Which of the two residual branches are present.
        self.mamba_branch = ssm_expand > 0
        self.mlp_branch = mlp_ratio > 0

        if self.mamba_branch:
            self.norm = norm_layer(hidden_dim)
            # Extra keyword arguments are forwarded to the SSM module.
            ssm_config = dict(
                d_model=hidden_dim,
                expand=ssm_expand,
                d_state=ssm_d_state,
                d_conv=ssm_conv,
                n_scan_directions=ssm_directions,
                dropout=ssm_drop_rate,
            )
            self.mamba = QuadMamba(**ssm_config, **kwargs)

        if self.mlp_branch:
            mlp_cls = gMlp if gmlp else Mlp
            self.norm2 = norm_layer(hidden_dim)
            self.mlp = mlp_cls(
                in_features=hidden_dim,
                hidden_features=int(hidden_dim * mlp_ratio),
                act_layer=mlp_act_layer,
                drop=mlp_drop_rate,
            )

        # A single stochastic-depth module is shared by both branches.
        if self.mamba_branch or self.mlp_branch:
            self.drop_path_ly = DropPath(drop_path_rate)

    def forward(self, input: torch.Tensor):
        # input: (b, h, w, c)
        out = input

        if self.mamba_branch:  # SSM
            ssm_out = self.mamba(self.norm(input))
            out = out + self.drop_path_ly(ssm_out)

        if self.mlp_branch:  # FFN
            ffn_out = self.mlp(self.norm2(out))
            out = out + self.drop_path_ly(ffn_out)

        return out
class MambaForSeriesForecasting(nn.Module):
    """Stack of ``Mambaformer`` stages followed by a final normalization.

    Args:
        dims: Channel width of each stage; ``dims[-1]`` sizes the final norm.
        depths: Number of blocks in each stage. ``len(depths)`` is the number
            of stages that are built.
        ssm_expand, ssm_d_state, ssm_conv, ssm_directions, ssm_drop_rate:
            Forwarded unchanged to every stage.
        mlp_ratio, mlp_drop_rate: Forwarded unchanged to every stage.
        drop_path_rate: Largest stochastic-depth rate. Per-block rates are
            spaced linearly up to this value over all blocks.
        norm_layer: Normalization class used by the stages and the final norm.
    """

    def __init__(
        self,
        # Tuples instead of lists: defaults must not be shared mutable state.
        dims=(768,),
        depths=(4,),
        ssm_expand=2,
        ssm_d_state=16,
        ssm_conv=3,
        ssm_directions=2,
        ssm_drop_rate=0.1,
        mlp_ratio=4,
        mlp_drop_rate=0.1,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_layers = len(depths)

        # stochastic depth decay rule: one rate per block; the leading 0 of
        # the linspace is dropped, so the last block gets drop_path_rate.
        dpr = torch.linspace(0, drop_path_rate, sum(depths) + 1)[1:].tolist()

        # build layers
        self.layers = nn.ModuleList()
        first_block = 0  # index into dpr of the current stage's first block
        for i_layer in range(self.num_layers):
            depth = depths[i_layer]
            layer = Mambaformer(
                dim=dims[i_layer],
                depth=depth,
                expand=ssm_expand,
                ssm_d_state=ssm_d_state,
                ssm_conv=ssm_conv,
                ssm_directions=ssm_directions,
                ssm_drop_rate=ssm_drop_rate,
                mlp_ratio=mlp_ratio,
                mlp_drop_rate=mlp_drop_rate,
                drop_path_rate=dpr[first_block : first_block + depth],
                norm_layer=norm_layer,
            )
            self.layers.append(layer)
            first_block += depth

        self.norm = norm_layer(dims[-1])

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Initialize Linear and LayerNorm modules; other modules are left alone."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: (b, h, w, c)

        for layer in self.layers:
            x = layer(x)  # b h w c

        x = self.norm(x)  # b h w c

        return x
1, 'test': 2} 30 | self.set_type = type_map[flag] 31 | 32 | self.features = features 33 | self.target = target 34 | self.scale = scale 35 | self.timeenc = timeenc 36 | self.freq = freq 37 | 38 | self.root_path = root_path 39 | self.data_path = data_path 40 | self.__read_data__() 41 | 42 | def __read_data__(self): 43 | self.scaler = StandardScaler() 44 | df_raw = pd.read_csv(os.path.join(self.root_path, 45 | self.data_path)) 46 | 47 | border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len] 48 | border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24] 49 | border1 = border1s[self.set_type] 50 | border2 = border2s[self.set_type] 51 | 52 | if self.features == 'M' or self.features == 'MS': 53 | cols_data = df_raw.columns[1:] 54 | df_data = df_raw[cols_data] 55 | elif self.features == 'S': 56 | df_data = df_raw[[self.target]] 57 | 58 | if self.scale: 59 | train_data = df_data[border1s[0]:border2s[0]] 60 | self.scaler.fit(train_data.values) 61 | data = self.scaler.transform(df_data.values) 62 | else: 63 | data = df_data.values 64 | 65 | df_stamp = df_raw[['date']][border1:border2] 66 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 67 | if self.timeenc == 0: 68 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 69 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 70 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 71 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 72 | data_stamp = df_stamp.drop(['date'], 1).values 73 | elif self.timeenc == 1: 74 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 75 | data_stamp = data_stamp.transpose(1, 0) 76 | 77 | self.data_x = data[border1:border2] 78 | self.data_y = data[border1:border2] 79 | self.data_stamp = data_stamp 80 | 81 | def __getitem__(self, index): 82 | s_begin = index 83 | s_end = s_begin + self.seq_len 84 | r_begin = s_end - self.label_len 85 | r_end = 
class Dataset_ETT_minute(Dataset):
    """Sliding-window dataset over an ETT CSV file with a ``date`` column.

    Args:
        root_path: Directory containing the CSV file.
        flag: One of 'train', 'val', 'test'; selects the row range.
        size: ``[seq_len, label_len, pred_len]``; defaults to
            ``[384, 96, 96]`` when omitted.
        features: 'M' or 'MS' uses every column after the first; 'S' uses
            only ``target``.
        data_path: CSV file name inside ``root_path``.
        target: Column used when ``features == 'S'``.
        scale: Standardize with statistics fitted on the training rows.
        timeenc: 0 for calendar fields (month, day, weekday, hour,
            minute // 15); 1 to use ``time_features``.
        freq: Frequency string passed to ``time_features``.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTm1.csv',
                 target='OT', scale=True, timeenc=0, freq='t'):
        # size [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Load the CSV, pick the split's rows, scale and encode time stamps."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # Split sizes are fixed multiples of 30 * 24 * 4 rows
        # (presumably 30 days of 15-minute samples - confirm against data).
        rows_per_unit = 30 * 24 * 4
        train_end = 12 * rows_per_unit
        val_end = train_end + 4 * rows_per_unit
        test_end = train_end + 8 * rows_per_unit
        # val/test start seq_len rows early so their first window has history.
        border1s = [0, train_end - self.seq_len, val_end - self.seq_len]
        border2s = [train_end, val_end, test_end]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features in ('M', 'MS'):
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            raise ValueError(
                f"features must be 'M', 'MS' or 'S', got {self.features!r}"
            )

        if self.scale:
            # Statistics come from the training rows only.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        dates = pd.to_datetime(df_raw['date'].iloc[border1:border2])

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = self._encode_timestamps(dates)

    def _encode_timestamps(self, dates):
        """Turn a datetime Series into the per-row time-mark array."""
        if self.timeenc == 0:
            # Vectorised accessors replace the per-row ``apply`` calls and
            # the positional ``drop(['date'], 1)``, which pandas 2 rejects.
            stamp = pd.DataFrame({
                'month': dates.dt.month,
                'day': dates.dt.day,
                'weekday': dates.dt.weekday,
                'hour': dates.dt.hour,
                'minute': dates.dt.minute // 15,
            })
            return stamp.to_numpy(dtype=np.int64)
        if self.timeenc == 1:
            stamp = time_features(pd.to_datetime(dates.values), freq=self.freq)
            return stamp.transpose(1, 0)
        raise ValueError(f"timeenc must be 0 or 1, got {self.timeenc!r}")

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        # The target window overlaps the last label_len input steps.
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        return self.scaler.inverse_transform(data)
# info 197 | if size == None: 198 | self.seq_len = 24 * 4 * 4 199 | self.label_len = 24 * 4 200 | self.pred_len = 24 * 4 201 | else: 202 | self.seq_len = size[0] 203 | self.label_len = size[1] 204 | self.pred_len = size[2] 205 | # init 206 | assert flag in ['train', 'test', 'val'] 207 | type_map = {'train': 0, 'val': 1, 'test': 2} 208 | self.set_type = type_map[flag] 209 | 210 | self.features = features 211 | self.target = target 212 | self.scale = scale 213 | self.timeenc = timeenc 214 | self.freq = freq 215 | 216 | self.root_path = root_path 217 | self.data_path = data_path 218 | self.__read_data__() 219 | 220 | def __read_data__(self): 221 | self.scaler = StandardScaler() 222 | df_raw = pd.read_csv(os.path.join(self.root_path, 223 | self.data_path)) 224 | 225 | ''' 226 | df_raw.columns: ['date', ...(other features), target feature] 227 | ''' 228 | cols = list(df_raw.columns) 229 | cols.remove(self.target) 230 | cols.remove('date') 231 | df_raw = df_raw[['date'] + cols + [self.target]] 232 | num_train = int(len(df_raw) * 0.7) 233 | num_test = int(len(df_raw) * 0.2) 234 | num_vali = len(df_raw) - num_train - num_test 235 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] 236 | border2s = [num_train, num_train + num_vali, len(df_raw)] 237 | border1 = border1s[self.set_type] 238 | border2 = border2s[self.set_type] 239 | 240 | if self.features == 'M' or self.features == 'MS': 241 | cols_data = df_raw.columns[1:] 242 | df_data = df_raw[cols_data] 243 | elif self.features == 'S': 244 | df_data = df_raw[[self.target]] 245 | 246 | if self.scale: 247 | train_data = df_data[border1s[0]:border2s[0]] 248 | self.scaler.fit(train_data.values) 249 | data = self.scaler.transform(df_data.values) 250 | else: 251 | data = df_data.values 252 | 253 | df_stamp = df_raw[['date']][border1:border2] 254 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 255 | if self.timeenc == 0: 256 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 257 | 
df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 258 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 259 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 260 | data_stamp = df_stamp.drop(['date'], 1).values 261 | elif self.timeenc == 1: 262 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 263 | data_stamp = data_stamp.transpose(1, 0) 264 | 265 | self.data_x = data[border1:border2] 266 | self.data_y = data[border1:border2] 267 | self.data_stamp = data_stamp 268 | 269 | def __getitem__(self, index): 270 | s_begin = index 271 | s_end = s_begin + self.seq_len 272 | r_begin = s_end - self.label_len 273 | r_end = r_begin + self.label_len + self.pred_len 274 | 275 | seq_x = self.data_x[s_begin:s_end] 276 | seq_y = self.data_y[r_begin:r_end] 277 | seq_x_mark = self.data_stamp[s_begin:s_end] 278 | seq_y_mark = self.data_stamp[r_begin:r_end] 279 | 280 | return seq_x, seq_y, seq_x_mark, seq_y_mark 281 | 282 | def __len__(self): 283 | return len(self.data_x) - self.seq_len - self.pred_len + 1 284 | 285 | def inverse_transform(self, data): 286 | return self.scaler.inverse_transform(data) 287 | 288 | 289 | class Dataset_PEMS(Dataset): 290 | def __init__(self, root_path, flag='train', size=None, 291 | features='S', data_path='ETTh1.csv', 292 | target='OT', scale=True, timeenc=0, freq='h'): 293 | # size [seq_len, label_len, pred_len] 294 | # info 295 | self.seq_len = size[0] 296 | self.label_len = size[1] 297 | self.pred_len = size[2] 298 | # init 299 | assert flag in ['train', 'test', 'val'] 300 | type_map = {'train': 0, 'val': 1, 'test': 2} 301 | self.set_type = type_map[flag] 302 | 303 | self.features = features 304 | self.target = target 305 | self.scale = scale 306 | self.timeenc = timeenc 307 | self.freq = freq 308 | 309 | self.root_path = root_path 310 | self.data_path = data_path 311 | self.__read_data__() 312 | 313 | def __read_data__(self): 314 | self.scaler = StandardScaler() 
class Dataset_Solar(Dataset):
    """Sliding-window dataset over a header-less, comma-separated text file.

    Rows are split in file order using 70% / 10% / 20% row counts for
    train / validation / test; validation and test slices start ``seq_len``
    rows early so their first window has a full input history. Time-mark
    tensors returned by ``__getitem__`` are all zeros.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        """Store the configuration and load the requested split.

        Args:
            root_path: directory containing the data file.
            flag: one of 'train', 'val', 'test'.
            size: ``[seq_len, label_len, pred_len]``. When ``None`` the same
                defaults as the sibling CSV datasets are used
                (``24*4*4, 24*4, 24*4``); previously ``None`` crashed with a
                ``TypeError``.
            features, target, timeenc, freq: stored on the instance; not read
                by the methods of this class.
            data_path: file name relative to ``root_path``.
            scale: standardise with statistics of the training rows.
        """
        # size [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    @staticmethod
    def _read_matrix(path):
        """Parse a comma-separated numeric text file into a 2-D float array.

        Blank lines are skipped (the original crashed on them with
        ``float('')``).

        Raises:
            ValueError: if the file has no data rows or a field is not a number.
        """
        with open(path, "r", encoding='utf-8') as f:
            # Iterate lazily instead of materialising readlines().
            rows = [[float(field) for field in line.split(',')]
                    for line in f if line.strip()]
        if not rows:
            raise ValueError(f"no data rows found in {path!r}")
        return np.asarray(rows, dtype=np.float64)

    def __read_data__(self):
        """Load the file, fit the scaler on the training rows, cache one split."""
        self.scaler = StandardScaler()
        df_data = self._read_matrix(os.path.join(self.root_path, self.data_path))

        num_rows = len(df_data)
        num_train = int(num_rows * 0.7)
        num_test = int(num_rows * 0.2)
        num_valid = int(num_rows * 0.1)
        border1s = [0, num_train - self.seq_len, num_rows - num_test - self.seq_len]
        border2s = [num_train, num_train + num_valid, num_rows]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.scale:
            # Fit on the training rows only, then transform the whole series.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data)
            data = self.scaler.transform(df_data)
        else:
            data = df_data

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]

    def __getitem__(self, index):
        """Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for window ``index``.

        ``seq_x`` covers ``seq_len`` rows starting at ``index``; ``seq_y``
        covers the last ``label_len`` rows of that window plus the following
        ``pred_len`` rows.
        """
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = torch.zeros((seq_x.shape[0], 1))
        # NOTE(review): sized from seq_x (not seq_y), kept as in the original;
        # confirm no consumer expects label_len + pred_len rows here.
        seq_y_mark = torch.zeros((seq_x.shape[0], 1))

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        """Number of windows that fit entirely inside the loaded split."""
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Undo the scaling applied in ``__read_data__``."""
        return self.scaler.inverse_transform(data)
class Dataset_Pred(Dataset):
    """Dataset exposing the last ``seq_len`` rows of a CSV for forecasting.

    Time features are produced for those rows plus ``pred_len`` future
    timestamps generated with ``pd.date_range`` at frequency ``freq``.
    """

    def __init__(self, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None):
        """Store the configuration and load the data.

        Args:
            root_path: directory containing the CSV.
            flag: must be 'pred'.
            size: ``[seq_len, label_len, pred_len]``; defaults to
                ``[24*4*4, 24*4, 24*4]`` when ``None``.
            features: 'M' or 'MS' (all non-date columns) or 'S' (target only).
            data_path: file name relative to ``root_path``.
            target: name of the target column.
            scale: standardise with statistics of the whole file.
            inverse: keep ``data_y`` unscaled (see ``__read_data__``).
            timeenc: 0 for integer calendar features, 1 for ``time_features``.
            freq: frequency string for the future timestamps.
            cols: optional list of feature columns (including the target) to
                keep; all columns are used when falsy.
        """
        # size [seq_len, label_len, pred_len]
        # info
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        # init
        assert flag in ['pred']

        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.freq = freq
        self.cols = cols
        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    @staticmethod
    def _calendar_features(dates):
        """Return an int64 array with columns month, day, weekday, hour, minute // 15."""
        index = pd.DatetimeIndex(dates)
        columns = [
            np.asarray(index.month),
            np.asarray(index.day),
            np.asarray(index.weekday),
            np.asarray(index.hour),
            np.asarray(index.minute) // 15,  # quarter-hour bucket
        ]
        # Cast so the dtype does not depend on the pandas version.
        return np.stack(columns, axis=1).astype(np.int64)

    def __read_data__(self):
        """Load the CSV, scale it and build time features for past + future.

        Sets ``self.data_x``, ``self.data_y`` and ``self.data_stamp``.

        Raises:
            ValueError: if ``self.features`` is not 'M', 'MS' or 'S', or if
                ``self.timeenc`` is not 0 or 1.
        """
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

        # Reorder columns to ['date', ...(other features), target feature].
        if self.cols:
            cols = self.cols.copy()
            cols.remove(self.target)
        else:
            cols = list(df_raw.columns)
            cols.remove(self.target)
            cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        border1 = len(df_raw) - self.seq_len
        border2 = len(df_raw)

        if self.features in ('M', 'MS'):
            df_data = df_raw[df_raw.columns[1:]]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on df_data.
            raise ValueError(f"unsupported features mode: {self.features!r}")

        if self.scale:
            # There is no train split here; the scaler sees the whole file.
            self.scaler.fit(df_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        # Known timestamps of the input window followed by pred_len generated
        # ones. date_range starts at the last known timestamp, so its first
        # element is dropped to avoid duplicating it.
        history = pd.DatetimeIndex(pd.to_datetime(df_raw['date'].iloc[border1:border2]))
        future = pd.date_range(history[-1], periods=self.pred_len + 1, freq=self.freq)[1:]
        dates = history.append(future)

        if self.timeenc == 0:
            # The original used ``drop(['date'], 1)``; the positional ``axis``
            # argument is rejected by pandas >= 2.0.
            data_stamp = self._calendar_features(dates)
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(dates.values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)
        else:
            raise ValueError(f"unsupported timeenc: {self.timeenc!r}")

        self.data_x = data[border1:border2]
        if self.inverse:
            self.data_y = df_data.values[border1:border2]
        else:
            self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for window ``index``.

        ``seq_y`` holds only the ``label_len`` known rows; ``seq_y_mark``
        additionally covers the ``pred_len`` future timestamps.
        """
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        # With inverse=True the label rows are taken from data_x, otherwise
        # from data_y (kept exactly as in the original).
        label_source = self.data_x if self.inverse else self.data_y
        seq_y = label_source[r_begin:r_begin + self.label_len]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        """Number of input windows available (1 when exactly seq_len rows are loaded)."""
        return len(self.data_x) - self.seq_len + 1

    def inverse_transform(self, data):
        """Undo the scaling applied in ``__read_data__``."""
        return self.scaler.inverse_transform(data)