├── timeprophet ├── __init__.py ├── experiments │ ├── __init__.py │ └── forecasting.py ├── data_modules │ ├── __init__.py │ ├── multivar.py │ ├── ETDataset.py │ └── base.py ├── models │ ├── __init__.py │ ├── NLinear.py │ ├── FITS.py │ ├── DLinear.py │ └── DiPE.py ├── logger │ └── __init__.py └── utils │ └── callbacks.py ├── dataset └── .gitignore ├── .style.yapf ├── configs ├── gpu.yaml ├── datasets │ ├── ETTh1 │ │ ├── 720_96.yaml │ │ ├── 720_192.yaml │ │ ├── 720_336.yaml │ │ └── 720_720.yaml │ ├── ETTh2 │ │ ├── 720_96.yaml │ │ ├── 720_192.yaml │ │ ├── 720_336.yaml │ │ └── 720_720.yaml │ ├── ETTm1 │ │ ├── 720_96.yaml │ │ ├── 720_192.yaml │ │ ├── 720_336.yaml │ │ └── 720_720.yaml │ ├── ETTm2 │ │ ├── 720_96.yaml │ │ ├── 720_192.yaml │ │ ├── 720_336.yaml │ │ └── 720_720.yaml │ ├── M5 │ │ ├── 60_24.yaml │ │ ├── 60_36.yaml │ │ ├── 60_48.yaml │ │ └── 60_60.yaml │ ├── Illness │ │ ├── 60_24.yaml │ │ ├── 60_36.yaml │ │ ├── 60_48.yaml │ │ └── 60_60.yaml │ ├── Traffic │ │ ├── 720_192.yaml │ │ ├── 720_336.yaml │ │ ├── 720_720.yaml │ │ └── 720_96.yaml │ ├── Weather │ │ ├── 720_192.yaml │ │ ├── 720_336.yaml │ │ ├── 720_720.yaml │ │ └── 720_96.yaml │ └── Electricity │ │ ├── 720_96.yaml │ │ ├── 720_192.yaml │ │ ├── 720_336.yaml │ │ └── 720_720.yaml └── models │ ├── DiPE │ └── base.yaml │ ├── FITS │ └── base.yaml │ ├── DLinear │ └── base.yaml │ ├── NLinear │ └── base.yaml │ └── RLinear │ └── base.yaml ├── .prettierrc ├── .github └── ISSUE_TEMPLATE │ └── config.yml ├── preprocessing └── M5.py ├── LICENSE ├── scripts ├── DLinear.sh ├── NLinear.sh ├── RLinear.sh ├── FITS.sh └── DiPE.sh ├── main.py ├── .gitignore ├── README.md ├── environment.yml └── .pylintrc /timeprophet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataset/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | -------------------------------------------------------------------------------- /configs/gpu.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | init_args: 3 | gpu: True 4 | -------------------------------------------------------------------------------- /timeprophet/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from .forecasting import LongTermForecasting 2 | 3 | __all__ = ['LongTermForecasting'] 4 | -------------------------------------------------------------------------------- /timeprophet/data_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .ETDataset import ETDataModule 2 | from .multivar import MultivarDataModule 3 | 4 | __all__ = ['ETDataModule', 'MultivarDataModule'] 5 | -------------------------------------------------------------------------------- /timeprophet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .DiPE import DiPE 2 | from .DLinear import DLinear 3 | from .FITS import FITS 4 | from .NLinear import NLinear 5 | 6 | __all__ = ['DLinear', 'FITS', 'DiPE', 'NLinear'] 7 | -------------------------------------------------------------------------------- /.prettierrc: 
-------------------------------------------------------------------------------- 1 | { 2 | "overrides": [ 3 | { 4 | "files": "configs/**/*.yaml", 5 | "options": { 6 | "tabWidth": 2, 7 | "useTabs": false 8 | } 9 | } 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /configs/datasets/ETTh1/720_96.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTh1.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 96 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTh2/720_96.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTh2.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 96 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTm1/720_96.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTm1.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 96 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTm2/720_96.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTm2.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 96 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/M5/60_24.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/m5.csv.gz 5 | batch_size: 8 6 | input_len: 60 7 | output_len: 24 8 | x_features: null 9 | y_features: null 10 | all_features_num: 30 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/M5/60_36.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/m5.csv.gz 5 | batch_size: 8 6 | input_len: 60 7 | output_len: 36 8 | x_features: null 9 | y_features: null 10 | all_features_num: 30 11 | 
preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/M5/60_48.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/m5.csv.gz 5 | batch_size: 8 6 | input_len: 60 7 | output_len: 48 8 | x_features: null 9 | y_features: null 10 | all_features_num: 30 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/M5/60_60.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/m5.csv.gz 5 | batch_size: 8 6 | input_len: 60 7 | output_len: 60 8 | x_features: null 9 | y_features: null 10 | all_features_num: 30 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTh1/720_192.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTh1.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 192 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTh1/720_336.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTh1.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 336 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTh1/720_720.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTh1.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 720 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTh2/720_192.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTh2.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 192 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | 
persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTh2/720_336.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTh2.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 336 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTh2/720_720.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTh2.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 720 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTm1/720_192.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTm1.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 192 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTm1/720_336.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTm1.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 336 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTm1/720_720.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTm1.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 720 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTm2/720_192.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTm2.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 192 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | 
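
All dataset configs in this repo share the same schema; only `dataset_path`, `batch_size`, the lookback/horizon pair (`input_len`, `output_len`), and `all_features_num` vary. As a hedged sketch of what these files expand to (the `TimeSeriesDataModule` base class is not included in this dump, so treat the constructor behavior as an assumption), the ETTh1 720→96 config above corresponds roughly to this direct instantiation:

```python
from timeprophet.data_modules import ETDataModule

# Mirrors configs/datasets/ETTh1/720_96.yaml; LightningCLI/jsonargparse
# normally performs this instantiation from the YAML init_args.
dm = ETDataModule(
    dataset_path="dataset/ETTh1.csv.gz",
    batch_size=64,
    input_len=720,    # lookback window length
    output_len=96,    # forecast horizon length
    x_features=None,  # null in the YAML
    y_features=None,  # null in the YAML
    all_features_num=7,
    preprocessor="sklearn.preprocessing.StandardScaler",  # assumed: class-path string resolved by the base class
    pin_memory=True,
    num_workers=4,
    persistent_workers=True,
)
```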
-------------------------------------------------------------------------------- /configs/datasets/ETTm2/720_336.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTm2.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 336 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/ETTm2/720_720.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.ETDataModule 3 | init_args: 4 | dataset_path: dataset/ETTm2.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 720 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Illness/60_24.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/illness.csv.gz 5 | batch_size: 8 6 | input_len: 60 7 | output_len: 24 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Illness/60_36.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/illness.csv.gz 5 | batch_size: 8 6 | input_len: 60 7 | output_len: 36 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Illness/60_48.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/illness.csv.gz 5 | batch_size: 8 6 | input_len: 60 7 | output_len: 48 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Illness/60_60.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/illness.csv.gz 5 | batch_size: 8 6 | input_len: 60 7 | output_len: 60 8 | x_features: null 9 | y_features: null 10 | all_features_num: 7 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | 
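
A note on these short-horizon settings: with an Informer-style sliding window (used by the data modules later in this dump), a contiguous split of length L yields at most L − input_len − output_len + 1 (input, target) pairs, which is why short datasets like Illness and M5 use a 60-step lookback. A small illustrative helper (hypothetical, not part of the repository):

```python
def num_windows(split_len: int, input_len: int, output_len: int) -> int:
    """Number of (input, target) pairs a sliding window yields over one split."""
    return max(split_len - input_len - output_len + 1, 0)

# e.g. a hypothetical 600-step training split under the Illness 60 -> 60 setting:
print(num_windows(600, 60, 60))  # 481
```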
-------------------------------------------------------------------------------- /configs/datasets/Traffic/720_192.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/traffic.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 192 8 | x_features: null 9 | y_features: null 10 | all_features_num: 862 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Traffic/720_336.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/traffic.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 336 8 | x_features: null 9 | y_features: null 10 | all_features_num: 862 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Traffic/720_720.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/traffic.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 720 8 | x_features: null 9 | y_features: null 10 | all_features_num: 862 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Traffic/720_96.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/traffic.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 96 8 | x_features: null 9 | y_features: null 10 | all_features_num: 862 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Weather/720_192.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/weather.csv.gz 5 | batch_size: 128 6 | input_len: 720 7 | output_len: 192 8 | x_features: null 9 | y_features: null 10 | all_features_num: 21 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Weather/720_336.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/weather.csv.gz 5 | batch_size: 128 6 | input_len: 720 7 | output_len: 336 8 | x_features: null 9 | y_features: null 10 | all_features_num: 21 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | 
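
Every config above points `preprocessor` at `sklearn.preprocessing.StandardScaler`. The base data module that applies it is not shown in this dump, but the convention the name implies is to fit the scaler on the training split only and reuse those statistics for validation and test, roughly:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy stand-ins for the three chronological splits (7 channels, as in ETT).
train, val, test = np.random.randn(500, 7), np.random.randn(100, 7), np.random.randn(100, 7)

scaler = StandardScaler()
train_scaled = scaler.fit_transform(train)  # statistics estimated on train only
val_scaled = scaler.transform(val)          # reuse train statistics: no leakage
test_scaled = scaler.transform(test)
```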
-------------------------------------------------------------------------------- /configs/datasets/Weather/720_720.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/weather.csv.gz 5 | batch_size: 128 6 | input_len: 720 7 | output_len: 720 8 | x_features: null 9 | y_features: null 10 | all_features_num: 21 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Weather/720_96.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/weather.csv.gz 5 | batch_size: 128 6 | input_len: 720 7 | output_len: 96 8 | x_features: null 9 | y_features: null 10 | all_features_num: 21 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Electricity/720_96.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/electricity.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 96 8 | x_features: null 9 | y_features: null 10 | all_features_num: 321 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Electricity/720_192.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/electricity.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 192 8 | x_features: null 9 | y_features: null 10 | all_features_num: 321 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Electricity/720_336.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/electricity.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 336 8 | x_features: null 9 | y_features: null 10 | all_features_num: 321 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 | -------------------------------------------------------------------------------- /configs/datasets/Electricity/720_720.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: timeprophet.data_modules.MultivarDataModule 3 | init_args: 4 | dataset_path: dataset/electricity.csv.gz 5 | batch_size: 64 6 | input_len: 720 7 | output_len: 720 8 | x_features: null 9 | y_features: null 10 | all_features_num: 321 11 | preprocessor: sklearn.preprocessing.StandardScaler 12 | pin_memory: true 13 | num_workers: 4 14 | persistent_workers: true 15 
| -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Q & A 4 | url: https://github.com/wintertee/DiPE-Linear/discussions/new?category=q-a 5 | about: Ask questions if you need help. 6 | - name: Ideas 7 | url: https://github.com/wintertee/DiPE-Linear/discussions/new?category=ideas 8 | about: For feature requests. 9 | - name: Bug reports 10 | url: https://github.com/wintertee/DiPE-Linear/discussions/new?category=general 11 | about: For suspected bugs or other unexpected behavior. 12 | -------------------------------------------------------------------------------- /timeprophet/logger/__init__.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning.loggers 2 | 3 | __all__ = ["TensorBoardLogger"] 4 | 5 | 6 | class TensorBoardLogger(pytorch_lightning.loggers.TensorBoardLogger): 7 | 8 | def __init__(self, save_dir, task: str, model: str, dataset: str, 9 | input_length: int, output_length: int, *args, **kwargs): 10 | super().__init__( 11 | save_dir, "/".join( 12 | [task, dataset, 13 | str(input_length), 14 | str(output_length), model]), *args, **kwargs) 15 | -------------------------------------------------------------------------------- /preprocessing/M5.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | train = pd.read_csv('sales_train_validation.csv') 4 | test = pd.read_csv('sales_test_validation.csv') 5 | 6 | df = pd.merge(train, 7 | test, 8 | on=['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'], 9 | how='inner') 10 | 11 | df = df.drop(columns=['item_id', 'dept_id', 'state_id']).groupby( 12 | ['cat_id', 'store_id']).sum().reset_index() 13 | 14 | df['cat_id_store_id'] = df['cat_id'] + '_' + df['store_id'] 15 | 16 | df = df.set_index('cat_id_store_id').drop(['cat_id', 'store_id'], axis=1).T 17 | 18 | df.insert(0, 'date', range(1, len(df) + 1)) 19 | 20 | df.to_csv('m5.csv.gz', index=False) 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 WinterTee and DiPE-Linear authors. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
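
To make the logger's directory scheme concrete, here is a usage sketch of the `TensorBoardLogger` subclass defined in `timeprophet/logger/__init__.py` above (values illustrative; `main.py` fills them in automatically from the configs):

```python
from timeprophet.logger import TensorBoardLogger

logger = TensorBoardLogger(
    save_dir="logs",
    task="LongTermForecasting",
    model="DiPE",
    dataset="ETTh1",
    input_length=720,
    output_length=96,
)
# The "/".join in __init__ produces the run name
#   LongTermForecasting/ETTh1/720/96/DiPE
# so runs land under logs/LongTermForecasting/ETTh1/720/96/DiPE/version_N.
```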
22 | -------------------------------------------------------------------------------- /configs/models/DiPE/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: timeprophet.experiments.forecasting.LongTermForecasting 3 | init_args: 4 | model: 5 | class_path: timeprophet.models.DiPE 6 | init_args: 7 | use_revin: True 8 | log_forecast: False 9 | 10 | optimizer: 11 | class_path: torch.optim.Adam 12 | init_args: 13 | lr: 1e-3 14 | 15 | lr_scheduler: 16 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 17 | init_args: 18 | T_max: 50 19 | eta_min: 0 20 | 21 | trainer: 22 | max_epochs: 50 23 | precision: 32 24 | logger: 25 | - class_path: timeprophet.logger.TensorBoardLogger 26 | init_args: 27 | save_dir: logs 28 | callbacks: 29 | - class_path: timeprophet.utils.callbacks.TemperatureScaling 30 | init_args: 31 | verbose: True 32 | - class_path: pytorch_lightning.callbacks.RichModelSummary 33 | init_args: 34 | max_depth: -1 35 | - class_path: pytorch_lightning.callbacks.RichProgressBar 36 | init_args: 37 | refresh_rate: 1 38 | leave: True 39 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 40 | init_args: 41 | monitor: val_loss 42 | verbose: False 43 | save_top_k: 1 44 | mode: min 45 | save_last: True 46 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 47 | init_args: 48 | logging_interval: epoch 49 | log_every_n_steps: 10 50 | -------------------------------------------------------------------------------- /configs/models/FITS/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: timeprophet.experiments.forecasting.LongTermForecasting 3 | init_args: 4 | model: 5 | class_path: timeprophet.models.FITS 6 | init_args: 7 | individual: False 8 | log_forecast: False 9 | 10 | optimizer: 11 | class_path: torch.optim.Adam 12 | init_args: 13 | lr: 1e-3 14 | 15 | lr_scheduler: 16 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 17 | init_args: 18 | T_max: 50 19 | eta_min: 0 20 | 21 | trainer: 22 | max_epochs: 50 23 | precision: 32 24 | logger: 25 | - class_path: timeprophet.logger.TensorBoardLogger 26 | init_args: 27 | save_dir: logs 28 | callbacks: 29 | - class_path: timeprophet.utils.callbacks.TemperatureScaling 30 | init_args: 31 | verbose: True 32 | - class_path: pytorch_lightning.callbacks.RichModelSummary 33 | init_args: 34 | max_depth: -1 35 | - class_path: pytorch_lightning.callbacks.RichProgressBar 36 | init_args: 37 | refresh_rate: 1 38 | leave: True 39 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 40 | init_args: 41 | monitor: val_loss 42 | verbose: False 43 | save_top_k: 1 44 | mode: min 45 | save_last: True 46 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 47 | init_args: 48 | logging_interval: epoch 49 | log_every_n_steps: 10 50 | -------------------------------------------------------------------------------- /configs/models/DLinear/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: timeprophet.experiments.forecasting.LongTermForecasting 3 | init_args: 4 | model: 5 | class_path: timeprophet.models.DLinear 6 | init_args: 7 | individual: False 8 | log_forecast: False 9 | 10 | optimizer: 11 | class_path: torch.optim.Adam 12 | init_args: 13 | lr: 1e-3 14 | 15 | lr_scheduler: 16 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 17 | init_args: 18 | T_max: 50 19 | eta_min: 0 20 | 21 | trainer: 22 | max_epochs: 50 23 | precision: 
32 24 | logger: 25 | - class_path: timeprophet.logger.TensorBoardLogger 26 | init_args: 27 | save_dir: logs 28 | callbacks: 29 | - class_path: timeprophet.utils.callbacks.TemperatureScaling 30 | init_args: 31 | verbose: True 32 | - class_path: pytorch_lightning.callbacks.RichModelSummary 33 | init_args: 34 | max_depth: -1 35 | - class_path: pytorch_lightning.callbacks.RichProgressBar 36 | init_args: 37 | refresh_rate: 1 38 | leave: True 39 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 40 | init_args: 41 | monitor: val_loss 42 | verbose: False 43 | save_top_k: 1 44 | mode: min 45 | save_last: True 46 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 47 | init_args: 48 | logging_interval: epoch 49 | log_every_n_steps: 10 50 | -------------------------------------------------------------------------------- /configs/models/NLinear/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: timeprophet.experiments.forecasting.LongTermForecasting 3 | init_args: 4 | model: 5 | class_path: timeprophet.models.NLinear.NLinear 6 | init_args: 7 | individual: False 8 | log_forecast: False 9 | 10 | optimizer: 11 | class_path: torch.optim.Adam 12 | init_args: 13 | lr: 1e-3 14 | 15 | lr_scheduler: 16 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 17 | init_args: 18 | T_max: 50 19 | eta_min: 0 20 | 21 | trainer: 22 | max_epochs: 50 23 | precision: 32 24 | logger: 25 | - class_path: timeprophet.logger.TensorBoardLogger 26 | init_args: 27 | save_dir: logs 28 | callbacks: 29 | - class_path: timeprophet.utils.callbacks.TemperatureScaling 30 | init_args: 31 | verbose: True 32 | - class_path: pytorch_lightning.callbacks.RichModelSummary 33 | init_args: 34 | max_depth: -1 35 | - class_path: pytorch_lightning.callbacks.RichProgressBar 36 | init_args: 37 | refresh_rate: 1 38 | leave: True 39 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 40 | init_args: 41 | monitor: val_loss 42 | verbose: False 43 | save_top_k: 1 44 | mode: min 45 | save_last: True 46 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 47 | init_args: 48 | logging_interval: epoch 49 | log_every_n_steps: 10 50 | -------------------------------------------------------------------------------- /timeprophet/data_modules/multivar.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from .base import TimeSeriesDataModule 4 | 5 | 6 | class MultivarDataModule(TimeSeriesDataModule): 7 | """DataModule for Multivariate Time series Datasets. 8 | 9 | For further information, please refer to the following repository: 10 | https://github.com/laiguokun/multivariate-time-series-data 11 | """ 12 | 13 | def __read_data__(self) -> pd.DataFrame: 14 | return pd.read_csv(self.dataset_path).drop('date', axis=1) 15 | 16 | def __split_data__(self, all_data: pd.DataFrame) -> tuple[pd.DataFrame]: 17 | """split the data into training, validation and test sets. 
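(The validation and test slices computed below are prefixed with the final `input_len` rows of the preceding split, so the first forecast window of each split has a full input history.)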
18 | 19 | we follow the method used in Informer: 20 | https://github.com/zhouhaoyi/Informer2020/blob/0ac81e04d4095ecb97a3a78c7b49c936d8aa9933/data/data_loader.py#L50 21 | https://github.com/zhouhaoyi/Informer2020/blob/0ac81e04d4095ecb97a3a78c7b49c936d8aa9933/data/data_loader.py#L136 22 | """ 23 | 24 | train_len = int(len(all_data) * 0.7) 25 | test_len = int(len(all_data) * 0.2) 26 | val_len = len(all_data) - train_len - test_len 27 | 28 | train_data = all_data[:train_len] 29 | val_data = all_data[train_len - self.input_len:train_len + val_len] 30 | test_data = all_data[train_len + val_len - self.input_len:train_len + 31 | val_len + test_len] 32 | 33 | return train_data, val_data, test_data 34 | -------------------------------------------------------------------------------- /configs/models/RLinear/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: timeprophet.experiments.forecasting.LongTermForecasting 3 | init_args: 4 | model: 5 | class_path: timeprophet.models.RLinear 6 | init_args: 7 | individual: False 8 | dropout_p: 0.1 9 | log_forecast: False 10 | 11 | optimizer: 12 | class_path: torch.optim.Adam 13 | init_args: 14 | lr: 1e-3 15 | 16 | lr_scheduler: 17 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 18 | init_args: 19 | T_max: 50 20 | eta_min: 0 21 | 22 | trainer: 23 | max_epochs: 50 24 | precision: 32 25 | logger: 26 | - class_path: timeprophet.logger.TensorBoardLogger 27 | init_args: 28 | save_dir: logs 29 | callbacks: 30 | - class_path: timeprophet.utils.callbacks.TemperatureScaling 31 | init_args: 32 | verbose: True 33 | - class_path: pytorch_lightning.callbacks.RichModelSummary 34 | init_args: 35 | max_depth: -1 36 | - class_path: pytorch_lightning.callbacks.RichProgressBar 37 | init_args: 38 | refresh_rate: 1 39 | leave: True 40 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 41 | init_args: 42 | monitor: val_loss 43 | verbose: False 44 | save_top_k: 1 45 | mode: min 46 | save_last: True 47 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 48 | init_args: 49 | logging_interval: epoch 50 | log_every_n_steps: 10 51 | -------------------------------------------------------------------------------- /timeprophet/data_modules/ETDataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from .base import TimeSeriesDataModule 4 | 5 | __all__ = ['ETDataModule'] 6 | 7 | 8 | class ETDataModule(TimeSeriesDataModule): 9 | """DataModule for ETDataset. 10 | 11 | For further information, please refer to the following repository: 12 | https://github.com/zhouhaoyi/ETDataset 13 | """ 14 | 15 | def __read_data__(self) -> pd.DataFrame: 16 | return pd.read_csv(self.dataset_path).drop('date', axis=1) 17 | 18 | def __split_data__(self, all_data: pd.DataFrame) -> tuple[pd.DataFrame]: 19 | """split the data into training, validation and test sets. 
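(The split lengths below follow the Informer protocol: 12 x 30 days of training data and 4 x 30 days each of validation and test data, at hourly resolution for ETTh* and 15-minute resolution for ETTm*.)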
20 | 21 | we follow the method used in Informer: 22 | https://github.com/zhouhaoyi/Informer2020/blob/0ac81e04d4095ecb97a3a78c7b49c936d8aa9933/data/data_loader.py#L50 23 | https://github.com/zhouhaoyi/Informer2020/blob/0ac81e04d4095ecb97a3a78c7b49c936d8aa9933/data/data_loader.py#L136 24 | """ 25 | if 'ETTh1' in self.dataset_path or 'ETTh2' in self.dataset_path: 26 | train_len = 12 * 30 * 24 27 | val_len = 4 * 30 * 24 28 | test_len = 4 * 30 * 24 29 | elif 'ETTm1' in self.dataset_path or 'ETTm2' in self.dataset_path: 30 | train_len = 12 * 30 * 24 * 4 31 | val_len = 4 * 30 * 24 * 4 32 | test_len = 4 * 30 * 24 * 4 33 | else: 34 | raise ValueError(f'unsupported ETDataset file: {self.dataset_path}') 35 | 36 | train_data = all_data[:train_len] 37 | val_data = all_data[train_len - self.input_len:train_len + val_len] 38 | test_data = all_data[train_len + val_len - self.input_len:train_len + 39 | val_len + test_len] 40 | 41 | return train_data, val_data, test_data 42 | -------------------------------------------------------------------------------- /scripts/DLinear.sh: -------------------------------------------------------------------------------- 1 | for length in 96 192 336 720; do 2 | 3 | # ETTh1 4 | python main.py \ 5 | --config configs/models/DLinear/base.yaml \ 6 | --config configs/datasets/ETTh1/720_$length.yaml \ 7 | --config configs/gpu.yaml 8 | 9 | # ETTh2 10 | python main.py \ 11 | --config configs/models/DLinear/base.yaml \ 12 | --config configs/datasets/ETTh2/720_$length.yaml \ 13 | --config configs/gpu.yaml 14 | 15 | # ETTm1 16 | python main.py \ 17 | --config configs/models/DLinear/base.yaml \ 18 | --config configs/datasets/ETTm1/720_$length.yaml \ 19 | --config configs/gpu.yaml 20 | 21 | # ETTm2 22 | python main.py \ 23 | --config configs/models/DLinear/base.yaml \ 24 | --config configs/datasets/ETTm2/720_$length.yaml \ 25 | --config configs/gpu.yaml 26 | 27 | # Electricity 28 | python main.py \ 29 | --config configs/models/DLinear/base.yaml \ 30 | --config configs/datasets/Electricity/720_$length.yaml \ 31 | --config configs/gpu.yaml 32 | 33 | # Weather 34 | python main.py \ 35 | --config configs/models/DLinear/base.yaml \ 36 | --config configs/datasets/Weather/720_$length.yaml \ 37 | --config configs/gpu.yaml \ 38 | --model.model.individual True 39 | 40 | done 41 | 42 | for length in 24 36 48 60; do 43 | # Illness 44 | python main.py \ 45 | --config configs/models/DLinear/base.yaml \ 46 | --config configs/datasets/Illness/60_$length.yaml \ 47 | --config configs/gpu.yaml 48 | 49 | # M5 50 | python main.py \ 51 | --config configs/models/DLinear/base.yaml \ 52 | --config configs/datasets/M5/60_$length.yaml \ 53 | --config configs/gpu.yaml 54 | done 55 | -------------------------------------------------------------------------------- /scripts/NLinear.sh: -------------------------------------------------------------------------------- 1 | for length in 96 192 336 720; do 2 | 3 | # ETTh1 4 | python main.py \ 5 | --config configs/models/NLinear/base.yaml \ 6 | --config configs/datasets/ETTh1/720_$length.yaml \ 7 | --config configs/gpu.yaml 8 | 9 | # ETTh2 10 | python main.py \ 11 | --config configs/models/NLinear/base.yaml \ 12 | --config configs/datasets/ETTh2/720_$length.yaml \ 13 | --config configs/gpu.yaml 14 | 15 | # ETTm1 16 | python main.py \ 17 | --config configs/models/NLinear/base.yaml \ 18 | --config configs/datasets/ETTm1/720_$length.yaml \ 19 | --config configs/gpu.yaml 20 | 21 | # ETTm2 22 | python main.py \ 23 | --config configs/models/NLinear/base.yaml \ 24 | --config configs/datasets/ETTm2/720_$length.yaml \ 25
| --config configs/gpu.yaml 26 | 27 | # Electricity 28 | python main.py \ 29 | --config configs/models/NLinear/base.yaml \ 30 | --config configs/datasets/Electricity/720_$length.yaml \ 31 | --config configs/gpu.yaml 32 | 33 | # Weather 34 | python main.py \ 35 | --config configs/models/NLinear/base.yaml \ 36 | --config configs/datasets/Weather/720_$length.yaml \ 37 | --config configs/gpu.yaml \ 38 | --model.model.individual True 39 | 40 | done 41 | 42 | for length in 24 36 48 60; do 43 | # Illness 44 | python main.py \ 45 | --config configs/models/NLinear/base.yaml \ 46 | --config configs/datasets/Illness/60_$length.yaml \ 47 | --config configs/gpu.yaml 48 | 49 | # M5 50 | python main.py \ 51 | --config configs/models/NLinear/base.yaml \ 52 | --config configs/datasets/M5/60_$length.yaml \ 53 | --config configs/gpu.yaml 54 | done 55 | -------------------------------------------------------------------------------- /scripts/RLinear.sh: -------------------------------------------------------------------------------- 1 | for length in 96 192 336 720; do 2 | 3 | # ETTh1 4 | python main.py \ 5 | --config configs/models/RLinear/base.yaml \ 6 | --config configs/datasets/ETTh1/720_$length.yaml \ 7 | --config configs/gpu.yaml 8 | 9 | # ETTh2 10 | python main.py \ 11 | --config configs/models/RLinear/base.yaml \ 12 | --config configs/datasets/ETTh2/720_$length.yaml \ 13 | --config configs/gpu.yaml 14 | 15 | # ETTm1 16 | python main.py \ 17 | --config configs/models/RLinear/base.yaml \ 18 | --config configs/datasets/ETTm1/720_$length.yaml \ 19 | --config configs/gpu.yaml 20 | 21 | # ETTm2 22 | python main.py \ 23 | --config configs/models/RLinear/base.yaml \ 24 | --config configs/datasets/ETTm2/720_$length.yaml \ 25 | --config configs/gpu.yaml 26 | 27 | # Electricity 28 | python main.py \ 29 | --config configs/models/RLinear/base.yaml \ 30 | --config configs/datasets/Electricity/720_$length.yaml \ 31 | --config configs/gpu.yaml 32 | 33 | # Weather 34 | python main.py \ 35 | --config configs/models/RLinear/base.yaml \ 36 | --config configs/datasets/Weather/720_$length.yaml \ 37 | --config configs/gpu.yaml \ 38 | --model.model.individual True 39 | 40 | done 41 | 42 | for length in 24 36 48 60; do 43 | # Illness 44 | python main.py \ 45 | --config configs/models/RLinear/base.yaml \ 46 | --config configs/datasets/Illness/60_$length.yaml \ 47 | --config configs/gpu.yaml 48 | 49 | # M5 50 | python main.py \ 51 | --config configs/models/RLinear/base.yaml \ 52 | --config configs/datasets/M5/60_$length.yaml \ 53 | --config configs/gpu.yaml 54 | done 55 | -------------------------------------------------------------------------------- /timeprophet/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | 2 | from pytorch_lightning.callbacks import Callback, EarlyStopping 3 | 4 | 5 | class AdjustLearningRate(Callback): 6 | 7 | def __init__(self, base_lr: float, verbose: bool = False): 8 | self.base_lr = base_lr 9 | self.verbose = verbose 10 | 11 | def on_train_epoch_end(self, trainer, pl_module): 12 | epoch = trainer.current_epoch 13 | lr = self.base_lr * (0.95**epoch) 14 | for param_group in trainer.optimizers[0].param_groups: 15 | param_group['lr'] = lr 16 | if self.verbose: 17 | print(f"Learning rate adjusted to {lr}") 18 | 19 | 20 | class TrainEarlyStopping(EarlyStopping): 21 | 22 | def on_validation_end(self, trainer, pl_module): 23 | # override this to disable early stopping at the end of the val loop 24 |
pass 25 | 26 | def on_train_end(self, trainer, pl_module): 27 | # instead, do it at the end of the training loop 28 | self._run_early_stopping_check(trainer) 29 | 30 | 31 | class TemperatureScaling(Callback): 32 | 33 | def __init__(self, verbose: bool = False): 34 | self.verbose = verbose 35 | self.min = 1 36 | 37 | def on_train_epoch_start(self, trainer, pl_module): 38 | t_max = 30 39 | t_min = self.min 40 | max_epoch = 10 41 | epoch = trainer.current_epoch 42 | if epoch > max_epoch: 43 | temperature = t_min 44 | else: 45 | temperature = t_max - (t_max - t_min) * epoch / max_epoch 46 | 47 | pl_module.model.temperature = temperature 48 | 49 | 50 | def on_test_epoch_start(self, trainer, pl_module): 51 | pl_module.model.temperature = self.min 52 | 53 | -------------------------------------------------------------------------------- /scripts/FITS.sh: -------------------------------------------------------------------------------- 1 | for length in 96 192 336 720; do 2 | 3 | # ETTh1 4 | python main.py \ 5 | --config configs/models/FITS/base.yaml \ 6 | --config configs/datasets/ETTh1/720_$length.yaml \ 7 | --config configs/gpu.yaml \ 8 | --model.model.base_T 24 --model.model.h_order 6 9 | 10 | # ETTh2 11 | python main.py \ 12 | --config configs/models/FITS/base.yaml \ 13 | --config configs/datasets/ETTh2/720_$length.yaml \ 14 | --config configs/gpu.yaml \ 15 | --model.model.base_T 24 --model.model.h_order 6 16 | 17 | # ETTm1 18 | python main.py \ 19 | --config configs/models/FITS/base.yaml \ 20 | --config configs/datasets/ETTm1/720_$length.yaml \ 21 | --config configs/gpu.yaml \ 22 | --model.model.base_T 96 --model.model.h_order 14 23 | 24 | # ETTm2 25 | python main.py \ 26 | --config configs/models/FITS/base.yaml \ 27 | --config configs/datasets/ETTm2/720_$length.yaml \ 28 | --config configs/gpu.yaml \ 29 | --model.model.base_T 96 --model.model.h_order 14 30 | 31 | # Electricity 32 | python main.py \ 33 | --config configs/models/FITS/base.yaml \ 34 | --config configs/datasets/Electricity/720_$length.yaml \ 35 | --config configs/gpu.yaml \ 36 | --model.model.base_T 24 --model.model.h_order 10 37 | 38 | # Weather 39 | python main.py \ 40 | --config configs/models/FITS/base.yaml \ 41 | --config configs/datasets/Weather/720_$length.yaml \ 42 | --config configs/gpu.yaml \ 43 | --model.model.individual True \ 44 | --model.model.base_T 144 --model.model.h_order 12 45 | 46 | done 47 | 48 | for length in 24 36 48 60; do 49 | # Illness 50 | python main.py \ 51 | --config configs/models/FITS/base.yaml \ 52 | --config configs/datasets/Illness/60_$length.yaml \ 53 | --config configs/gpu.yaml \ 54 | --model.model.base_T 52 --model.model.h_order 10 55 | 56 | # M5 57 | python main.py \ 58 | --config configs/models/FITS/base.yaml \ 59 | --config configs/datasets/M5/60_$length.yaml \ 60 | --config configs/gpu.yaml \ 61 | --model.model.base_T 7 --model.model.h_order 2 62 | done 63 | -------------------------------------------------------------------------------- /timeprophet/models/NLinear.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2022 DLinear Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class NLinear(nn.Module): 23 | """ 24 | Normalization-Linear 25 | """ 26 | 27 | def __init__(self, input_len, output_len, individual, input_features, 28 | output_features): 29 | super(NLinear, self).__init__() 30 | self.seq_len = input_len 31 | self.pred_len = output_len 32 | self.example_input_array = torch.Tensor(32, input_len, input_features) 33 | 34 | # Use this line if you want to visualize the weights 35 | # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 36 | self.channels = input_features 37 | self.individual = individual 38 | if self.individual: 39 | self.Linear = nn.ModuleList() 40 | for i in range(self.channels): 41 | self.Linear.append(nn.Linear(self.seq_len, self.pred_len)) 42 | else: 43 | self.Linear = nn.Linear(self.seq_len, self.pred_len) 44 | 45 | def forward(self, x): 46 | # x: [Batch, Input length, Channel] 47 | seq_last = x[:, -1:, :].detach() 48 | x = x - seq_last 49 | if self.individual: 50 | output = torch.zeros( 51 | [x.size(0), self.pred_len, x.size(2)], 52 | dtype=x.dtype).to(x.device) 53 | for i in range(self.channels): 54 | output[:, :, i] = self.Linear[i](x[:, :, i]) 55 | x = output 56 | else: 57 | x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1) 58 | x = x + seq_last 59 | return x # [Batch, Output length, Channel] 60 | 61 | def loss(self, y, y_hat): 62 | return F.mse_loss(y, y_hat) 63 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import yaml 5 | from pytorch_lightning.cli import LightningCLI 6 | 7 | 8 | class CLI(LightningCLI): 9 | 10 | def add_arguments_to_parser(self, parser): 11 | parser.link_arguments("data.init_args.input_len", 12 | "model.init_args.model.init_args.input_len") 13 | parser.link_arguments("data.init_args.output_len", 14 | "model.init_args.model.init_args.output_len") 15 | 16 | parser.link_arguments("data.x_features_num", 17 | "model.init_args.model.init_args.input_features", 18 | apply_on="instantiate") 19 | parser.link_arguments("data.y_features_num", 20 | "model.init_args.model.init_args.output_features", 21 | apply_on="instantiate") 22 | 23 | parser.link_arguments("model.class_path", 24 | "trainer.logger.init_args.task", 25 | compute_fn=lambda s: s.split('.')[-1].strip()) 26 | parser.link_arguments("model.init_args.model.class_path", 27 | "trainer.logger.init_args.model", 28 | compute_fn=lambda s: s.split('.')[-1].strip()) 29 | parser.link_arguments( 30 | "data.init_args.dataset_path", 31 | "trainer.logger.init_args.dataset", 32 | compute_fn=lambda s: s.split('/')[-1].split('.')[0].strip()) 33 | parser.link_arguments("data.init_args.input_len", 34 | "trainer.logger.init_args.input_length") 35 | parser.link_arguments("data.init_args.output_len", 36 | "trainer.logger.init_args.output_length") 37 | 38 | 39 | torch.set_float32_matmul_precision("high") 40 | 41 | cli = CLI(run=False) 42 | 43 | 
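# With run=False, LightningCLI only parses the stacked --config files (keys in
# later files override earlier ones) and instantiates cli.model and
# cli.datamodule; fitting and testing are then driven manually below so that
# both the best and the last checkpoint can be evaluated.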
cli.trainer.fit(cli.model, cli.datamodule) 44 | best_model_path = cli.trainer.checkpoint_callback.best_model_path 45 | last_model_path = cli.trainer.checkpoint_callback.last_model_path 46 | 47 | best_result = cli.trainer.test( 48 | cli.model, 49 | datamodule=cli.datamodule, 50 | ckpt_path=best_model_path, 51 | ) 52 | 53 | last_result = cli.trainer.test( 54 | cli.model, 55 | datamodule=cli.datamodule, 56 | ckpt_path=last_model_path, 57 | ) 58 | 59 | result = { 60 | "best": best_result, 61 | "last": last_result, 62 | "best_path": best_model_path 63 | } 64 | with open(os.path.join(cli.trainer.logger.log_dir, "test_result.yaml"), 65 | "w", 66 | encoding='utf-8') as f: 67 | yaml.dump(result, f) 68 | -------------------------------------------------------------------------------- /scripts/DiPE.sh: -------------------------------------------------------------------------------- 1 | for length in 96 192 336 720; do 2 | 3 | # ETTh1 4 | python main.py \ 5 | --config configs/models/DiPE/base.yaml \ 6 | --config configs/datasets/ETTh1/720_$length.yaml \ 7 | --config configs/gpu.yaml \ 8 | --model.model.num_experts 1 \ 9 | --model.model.loss_alpha 1 \ 10 | --model.model.use_revin True 11 | 12 | # ETTh2 13 | python main.py \ 14 | --config configs/models/DiPE/base.yaml \ 15 | --config configs/datasets/ETTh2/720_$length.yaml \ 16 | --config configs/gpu.yaml \ 17 | --model.model.num_experts 3 \ 18 | --model.model.loss_alpha 0.9 \ 19 | --model.model.use_revin True 20 | 21 | # ETTm1 22 | python main.py \ 23 | --config configs/models/DiPE/base.yaml \ 24 | --config configs/datasets/ETTm1/720_$length.yaml \ 25 | --config configs/gpu.yaml \ 26 | --model.model.num_experts 1 \ 27 | --model.model.loss_alpha 1 \ 28 | --model.model.use_revin True 29 | 30 | # ETTm2 31 | python main.py \ 32 | --config configs/models/DiPE/base.yaml \ 33 | --config configs/datasets/ETTm2/720_$length.yaml \ 34 | --config configs/gpu.yaml \ 35 | --model.model.num_experts 1 \ 36 | --model.model.loss_alpha 0.9 \ 37 | --model.model.use_revin True 38 | 39 | # Electricity 40 | python main.py \ 41 | --config configs/models/DiPE/base.yaml \ 42 | --config configs/datasets/Electricity/720_$length.yaml \ 43 | --config configs/gpu.yaml \ 44 | --model.model.num_experts 4 \ 45 | --model.model.loss_alpha 0.3 \ 46 | --model.model.use_revin False 47 | 48 | # Weather 49 | python main.py \ 50 | --config configs/models/DiPE/base.yaml \ 51 | --config configs/datasets/Weather/720_$length.yaml \ 52 | --config configs/gpu.yaml \ 53 | --model.model.num_experts 4 \ 54 | --model.model.loss_alpha 0.9 \ 55 | --model.model.use_revin False 56 | 57 | done 58 | 59 | for length in 24 36 48 60; do 60 | # Illness 61 | python main.py \ 62 | --config configs/models/DiPE/base.yaml \ 63 | --config configs/datasets/Illness/60_$length.yaml \ 64 | --config configs/gpu.yaml \ 65 | --model.model.num_experts 1 \ 66 | --model.model.loss_alpha 1 \ 67 | --model.model.use_revin True 68 | 69 | # M5 70 | python main.py \ 71 | --config configs/models/DiPE/base.yaml \ 72 | --config configs/datasets/M5/60_$length.yaml \ 73 | --config configs/gpu.yaml \ 74 | --model.model.num_experts 4 \ 75 | --model.model.loss_alpha 0 \ 76 | --model.model.use_revin True 77 | done 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution /
packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | lightning_logs/ 165 | checkpoints/ 166 | events.out.tfevents.* 167 | logs/ 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Disentangled Interpretable Representation for Efficient Long-term Time Series Forecasting 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2411.17257-B31B1B.svg?logo=arxiv)](https://arxiv.org/abs/2411.17257) 4 | [![DOI](https://img.shields.io/badge/DOI-10.48550/arXiv.2411.17257-FAB70C.svg?logo=DOI)](https://doi.org/10.48550/arXiv.2411.17257) 5 | [![license](https://img.shields.io/github/license/wintertee/DiPE-Linear?style=flat)](https://github.com/wintertee/DiPE-Linear/blob/main/LICENSE) 6 | 7 | 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/disentangled-interpretable-representation-for/time-series-forecasting-on-etth1-720-1)](https://paperswithcode.com/sota/time-series-forecasting-on-etth1-720-1?p=disentangled-interpretable-representation-for) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/disentangled-interpretable-representation-for/time-series-forecasting-on-etth2-720-1)](https://paperswithcode.com/sota/time-series-forecasting-on-etth2-720-1?p=disentangled-interpretable-representation-for) 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/disentangled-interpretable-representation-for/time-series-forecasting-on-ettm1-720-1)](https://paperswithcode.com/sota/time-series-forecasting-on-ettm1-720-1?p=disentangled-interpretable-representation-for) 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/disentangled-interpretable-representation-for/time-series-forecasting-on-ettm2-720-1)](https://paperswithcode.com/sota/time-series-forecasting-on-ettm2-720-1?p=disentangled-interpretable-representation-for) 12 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/disentangled-interpretable-representation-for/time-series-forecasting-on-weather-720)](https://paperswithcode.com/sota/time-series-forecasting-on-weather-720?p=disentangled-interpretable-representation-for) 13 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/disentangled-interpretable-representation-for/time-series-forecasting-on-electricity-720)](https://paperswithcode.com/sota/time-series-forecasting-on-electricity-720?p=disentangled-interpretable-representation-for) 14 | 15 | The official implementation of the paper "Disentangled Interpretable Representation for Efficient Long-term Time Series Forecasting" 16 | 17 | ## Requirements 18 | 19 | We recommend using the latest versions of dependencies. However, you can refer to the `environment.yml` file to set up the same environment we used. 20 | 21 | ## Dataset 22 | 23 | All datasets are stored as CSV files compressed in gzip (`.gz`) format. Please place the datasets in the `./dataset` directory. 24 | 25 | - For the M5 dataset, we recommend downloading it from [M5-methods](https://github.com/Mcompetitions/M5-methods) and preprocessing it using `preprocessing/M5.py`. 26 | - For other datasets, we recommend downloading them from [Autoformer](https://github.com/thuml/Autoformer). 27 | 28 | ## Usage 29 | 30 | All experiments can be reproduced using the `scripts/DiPE.sh` script. 31 | 32 | ## Citation 33 | 34 | If you find this repo useful, please cite our paper: 35 | 36 | ```bibtex 37 | @misc{zhao2024dipe, 38 | title={Disentangled Interpretable Representation for Efficient Long-term Time Series Forecasting}, 39 | author={Yuang Zhao and Tianyu Li and Jiadong Chen and Shenrong Ye and Fuxin Jiang and Tieying Zhang and Xiaofeng Gao}, 40 | year={2024}, 41 | eprint={2411.17257}, 42 | archivePrefix={arXiv}, 43 | primaryClass={cs.LG}, 44 | url={https://arxiv.org/abs/2411.17257}, 45 | } 46 | ``` 47 | 48 | ## License 49 | 50 | This repo is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 51 | 52 | ## Acknowledgments 53 | 54 | - [ETDataset](https://github.com/zhouhaoyi/ETDataset) 55 | - [multivariate-time-series-data](https://github.com/laiguokun/multivariate-time-series-data) 56 | - [Autoformer](https://github.com/thuml/Autoformer) 57 | - [LTSF-Linear](https://github.com/cure-lab/LTSF-Linear) 58 | - [RTSF](https://github.com/plumprc/RTSF) 59 | - [FITS](https://github.com/VEWOXIC/FITS) 60 | -------------------------------------------------------------------------------- /timeprophet/models/FITS.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2023 VEWOXIC 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
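Adapted from the FITS implementation acknowledged in the README: https://github.com/VEWOXIC/FITS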
15 | """ 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class FITS(nn.Module): 23 | 24 | # FITS: Frequency Interpolation Time Series Forecasting 25 | 26 | def __init__( 27 | self, 28 | input_len: int, 29 | output_len: int, 30 | input_features: int, 31 | output_features: int, 32 | individual: bool, 33 | base_T: int, 34 | h_order: int, 35 | ): 36 | super().__init__() 37 | self.example_input_array = torch.Tensor(32, input_len, input_features) 38 | self.seq_len = input_len 39 | self.pred_len = output_len 40 | self.individual = individual 41 | self.channels = input_features 42 | 43 | self.dominance_freq = (self.seq_len // base_T + 1) * h_order + 10 44 | self.length_ratio = (self.seq_len + self.pred_len) / self.seq_len 45 | 46 | if self.individual: 47 | self.freq_upsampler = nn.ModuleList() 48 | for i in range(self.channels): 49 | self.freq_upsampler.append( 50 | nn.Linear(self.dominance_freq, 51 | int(self.dominance_freq * self.length_ratio)).to( 52 | torch.cfloat)) 53 | 54 | else: 55 | self.freq_upsampler = nn.Linear( 56 | self.dominance_freq, 57 | int(self.dominance_freq * self.length_ratio)).to( 58 | torch.cfloat) # complex layer for frequency upcampling] 59 | # configs.pred_len=configs.seq_len+configs.pred_len 60 | # #self.Dlinear=DLinear.Model(configs) 61 | # configs.pred_len=self.pred_len 62 | 63 | def forward(self, x): 64 | # RIN 65 | x_mean = torch.mean(x, dim=1, keepdim=True) 66 | x = x - x_mean 67 | x_var = torch.var(x, dim=1, keepdim=True) + 1e-5 68 | # print(x_var) 69 | x = x / torch.sqrt(x_var) 70 | 71 | low_specx = torch.fft.rfft(x, dim=1) 72 | low_specx[:, self.dominance_freq:] = 0 # LPF 73 | low_specx = low_specx[:, 0:self.dominance_freq, :] # LPF 74 | # print(low_specx.permute(0,2,1)) 75 | if self.individual: 76 | low_specxy_ = torch.zeros([ 77 | low_specx.size(0), 78 | int(self.dominance_freq * self.length_ratio), 79 | low_specx.size(2) 80 | ], 81 | dtype=low_specx.dtype).to( 82 | low_specx.device) 83 | for i in range(self.channels): 84 | low_specxy_[:, :, i] = self.freq_upsampler[i]( 85 | low_specx[:, :, i].permute(0, 1)).permute(0, 1) 86 | else: 87 | low_specxy_ = self.freq_upsampler(low_specx.permute(0, 2, 88 | 1)).permute( 89 | 0, 2, 1) 90 | # print(low_specxy_) 91 | low_specxy = torch.zeros([ 92 | low_specxy_.size(0), 93 | int((self.seq_len + self.pred_len) / 2 + 1), 94 | low_specxy_.size(2) 95 | ], 96 | dtype=low_specxy_.dtype).to(low_specxy_.device) 97 | low_specxy[:, 0:low_specxy_.size(1), :] = low_specxy_ # zero padding 98 | low_xy = torch.fft.irfft(low_specxy, dim=1) 99 | low_xy = low_xy * self.length_ratio # energy compemsation for the length change 100 | # dom_x=x-low_x 101 | 102 | # dom_xy=self.Dlinear(dom_x) 103 | # xy=(low_xy+dom_xy) * torch.sqrt(x_var) +x_mean # REVERSE RIN 104 | xy = (low_xy) * torch.sqrt(x_var) + x_mean 105 | return xy[:, -self.pred_len:, :] 106 | 107 | def loss(self, y, y_hat): 108 | return F.mse_loss(y, y_hat) 109 | -------------------------------------------------------------------------------- /timeprophet/models/DLinear.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2022 DLinear Authors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class moving_avg(nn.Module): 23 | """ 24 | Moving average block to highlight the trend of time series 25 | """ 26 | 27 | def __init__(self, kernel_size, stride): 28 | super(moving_avg, self).__init__() 29 | self.kernel_size = kernel_size 30 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, 31 | stride=stride, 32 | padding=0) 33 | 34 | def forward(self, x): 35 | # padding on both ends of the time series 36 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 37 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 38 | x = torch.cat([front, x, end], dim=1) 39 | x = self.avg(x.permute(0, 2, 1)) 40 | x = x.permute(0, 2, 1) 41 | return x 42 | 43 | 44 | class series_decomp(nn.Module): 45 | """ 46 | Series decomposition block 47 | """ 48 | 49 | def __init__(self, kernel_size): 50 | super(series_decomp, self).__init__() 51 | self.moving_avg = moving_avg(kernel_size, stride=1) 52 | 53 | def forward(self, x): 54 | moving_mean = self.moving_avg(x) 55 | res = x - moving_mean 56 | return res, moving_mean 57 | 58 | 59 | class DLinear(nn.Module): 60 | """ 61 | Decomposition-Linear 62 | """ 63 | 64 | def __init__(self, input_len, output_len, individual, input_features, 65 | output_features): 66 | super().__init__() 67 | self.input_len = input_len 68 | self.output_len = output_len 69 | assert input_features == output_features 70 | 71 | # Decomposition kernel size 72 | kernel_size = 25 73 | self.decomposition = series_decomp(kernel_size) 74 | self.individual = individual 75 | self.channels = input_features 76 | 77 | self.example_input_array = torch.Tensor(32, input_len, input_features) 78 | 79 | if self.individual: 80 | self.linear_seasonal = nn.ModuleList() 81 | self.linear_trend = nn.ModuleList() 82 | 83 | for i in range(self.channels): 84 | self.linear_seasonal.append( 85 | nn.Linear(self.input_len, self.output_len)) 86 | self.linear_trend.append( 87 | nn.Linear(self.input_len, self.output_len)) 88 | 89 | # Uncomment these two lines if you want to visualize the weights 90 | # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 91 | # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 92 | else: 93 | self.linear_seasonal = nn.Linear(self.input_len, self.output_len) 94 | self.linear_trend = nn.Linear(self.input_len, self.output_len) 95 | 96 | # Uncomment these two lines if you want to visualize the weights 97 | # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 98 | # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 99 | 100 | def forward(self, x): 101 | # x: [Batch, Input length, Channel] 102 | seasonal_init, trend_init = self.decomposition(x) 103 | seasonal_init, trend_init = seasonal_init.permute( 104 | 0, 2, 1), trend_init.permute(0, 2, 1) 105 | if self.individual: 106 | seasonal_output = torch.zeros( 107 | [seasonal_init.size(0), 108 | seasonal_init.size(1),
self.output_len], 109 | dtype=seasonal_init.dtype).to(seasonal_init.device) 110 | trend_output = torch.zeros( 111 | [trend_init.size(0), 112 | trend_init.size(1), self.output_len], 113 | dtype=trend_init.dtype).to(trend_init.device) 114 | for i in range(self.channels): 115 | seasonal_output[:, i, :] = self.linear_seasonal[i]( 116 | seasonal_init[:, i, :]) 117 | trend_output[:, i, :] = self.linear_trend[i](trend_init[:, 118 | i, :]) 119 | else: 120 | seasonal_output = self.linear_seasonal(seasonal_init) 121 | trend_output = self.linear_trend(trend_init) 122 | 123 | x = seasonal_output + trend_output 124 | return x.permute(0, 2, 1) # to [Batch, Output length, Channel] 125 | 126 | def loss(self, y, y_hat): 127 | return F.mse_loss(y, y_hat) 128 | -------------------------------------------------------------------------------- /timeprophet/experiments/forecasting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pytorch_lightning as L 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import Tensor, nn 7 | 8 | 9 | class LongTermForecasting(L.LightningModule): 10 | 11 | def __init__(self, 12 | model: nn.Module, 13 | log_forecast: bool = False, 14 | profile=False): 15 | super().__init__() 16 | self.model = model 17 | self.example_input_array = self.model.example_input_array 18 | self.log_forecast = log_forecast 19 | self.profile = profile 20 | self.profile_time = [] 21 | 22 | self.save_hyperparameters() 23 | 24 | self.metrics_fn = { 25 | 'mse': F.mse_loss, 26 | 'mae': F.l1_loss, 27 | } 28 | 29 | def forward(self, x) -> Tensor: 30 | return self.model(x) 31 | 32 | def shared_step(self, x, y): 33 | y_hat = self.forward(x) 34 | 35 | with torch.no_grad(): 36 | metrics = { 37 | metric_name: metric_fn(y_hat, y) 38 | for metric_name, metric_fn in self.metrics_fn.items() 39 | } 40 | 41 | return metrics, y_hat 42 | 43 | def training_step(self, batch, batch_idx): 44 | x, y = batch 45 | metrics, y_hat = self.shared_step(x, y) 46 | loss = self.model.loss(y, y_hat) 47 | 48 | self.log_dict( 49 | { 50 | f'train_{metric_name}': metric_value 51 | for metric_name, metric_value in metrics.items() 52 | }, 53 | on_step=False, 54 | on_epoch=True, 55 | prog_bar=True, 56 | ) 57 | self.log( 58 | "train_loss", 59 | loss, 60 | on_step=False, 61 | on_epoch=True, 62 | prog_bar=True, 63 | ) 64 | 65 | self.log_add_forecast('train', batch_idx, batch[0], y_hat, batch[1]) 66 | 67 | return loss 68 | 69 | def validation_step(self, batch, batch_idx): 70 | x, y = batch 71 | metrics, y_hat = self.shared_step(x, y) 72 | loss = self.model.loss(y, y_hat) 73 | self.log_dict( 74 | { 75 | f'val_{metric_name}': metric_value 76 | for metric_name, metric_value in metrics.items() 77 | }, 78 | on_step=False, 79 | on_epoch=True, 80 | prog_bar=True, 81 | ) 82 | self.log( 83 | "val_loss", 84 | loss, 85 | on_step=False, 86 | on_epoch=True, 87 | prog_bar=True, 88 | ) 89 | 90 | self.log_add_forecast('val', batch_idx, batch[0], y_hat, batch[1]) 91 | 92 | def test_step(self, batch, batch_idx): 93 | x, y = batch 94 | 95 | if self.profile: 96 | torch.cuda.empty_cache() 97 | start_event = torch.cuda.Event(enable_timing=True) 98 | end_event = torch.cuda.Event(enable_timing=True) 99 | torch.cuda.synchronize() 100 | start_event.record() 101 | metrics, _ = self.shared_step(x, y) 102 | end_event.record() 103 | torch.cuda.synchronize() 104 | inference_time = start_event.elapsed_time(end_event) 105 | self.profile_time.append(inference_time) 106 
| print(f"Inference time: {np.sum(self.profile_time)} ms") 107 | else: 108 | metrics, _ = self.shared_step(x, y) 109 | self.log_dict( 110 | { 111 | f'test_{metric_name}': metric_value 112 | for metric_name, metric_value in metrics.items() 113 | }, 114 | on_step=False, 115 | on_epoch=True, 116 | prog_bar=True, 117 | ) 118 | 119 | def log_add_forecast(self, name, batch_idx, x, y_hat, y): 120 | 121 | if (batch_idx % self.trainer.log_every_n_steps == 0 and 122 | self.log_forecast and 123 | isinstance(self.logger.experiment, 124 | torch.utils.tensorboard.writer.SummaryWriter)): 125 | 126 | tensorboard = self.logger.experiment 127 | 128 | x_np = x[0].detach().cpu().numpy() 129 | y_np = y[0].detach().cpu().numpy() 130 | y_hat_np = y_hat[0].detach().cpu().numpy() 131 | 132 | num_plots = min(10, x.shape[2]) 133 | rows = (num_plots + 2) // 3 134 | cols = 3 135 | 136 | fig, axs = plt.subplots(rows, cols, figsize=(12, 2 * rows)) 137 | 138 | if rows == 1: 139 | axs = np.expand_dims(axs, axis=0) 140 | 141 | for i in range(num_plots): 142 | row = i // cols 143 | col = i % cols 144 | ax = axs[row, col] 145 | 146 | ax.plot(range(self.model.input_len), 147 | x_np[:, i], 148 | label='Input Sequence', 149 | color='black') 150 | ax.plot(range(self.model.input_len, 151 | self.model.input_len + self.model.output_len), 152 | y_np[:, i], 153 | color='green') 154 | ax.plot(range(self.model.input_len, 155 | self.model.input_len + self.model.output_len), 156 | y_hat_np[:, i], 157 | label='Prediction', 158 | color='blue', 159 | linestyle='dashed', 160 | alpha=0.5) 161 | 162 | # ax.set_title(f'Forecast for Feature {i}') 163 | # ax.set_xlabel('Time Steps') 164 | # ax.set_ylabel('Value') 165 | 166 | for j in range(num_plots, rows * cols): 167 | fig.delaxes(axs[j // cols, j % cols]) 168 | 169 | plt.subplots_adjust(left=0.05, 170 | right=0.95, 171 | top=0.95, 172 | bottom=0.05, 173 | wspace=0.3, 174 | hspace=0.3) 175 | 176 | tensorboard.add_figure(name, fig, self.global_step) 177 | -------------------------------------------------------------------------------- /timeprophet/data_modules/base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Type 2 | 3 | import pandas as pd 4 | import pytorch_lightning as L 5 | import torch 6 | from sklearn.base import BaseEstimator 7 | from sklearn.preprocessing import StandardScaler 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class TimeSeriesDataset(Dataset): 12 | 13 | def __init__(self, data_x: torch.Tensor, data_y: torch.Tensor, 14 | input_len: int, output_len: int): 15 | 16 | assert len(data_x) == len(data_y) 17 | 18 | self.data_x = data_x 19 | self.data_y = data_y 20 | self.input_len = input_len 21 | self.output_len = output_len 22 | 23 | def __len__(self) -> int: 24 | return len(self.data_x) - self.input_len - self.output_len + 1 25 | 26 | def __getitem__(self, idx) -> tuple[torch.Tensor]: 27 | return (self.data_x[idx:idx + self.input_len], 28 | self.data_y[idx + self.input_len:idx + self.input_len + 29 | self.output_len]) 30 | 31 | 32 | class TimeSeriesDataModule(L.LightningDataModule): 33 | """This is a Dataset base class for time series data. 
34 | 35 | """ 36 | 37 | _all_features_num = None 38 | 39 | def __init__(self, 40 | dataset_path: str, 41 | batch_size: int, 42 | input_len: int, 43 | output_len: int, 44 | x_features: List[int] | None = None, 45 | y_features: List[int] | None = None, 46 | all_features_num: int | None = None, 47 | preprocessor: Type[BaseEstimator] = StandardScaler, 48 | pin_memory: bool = True, 49 | num_workers: int = 0, 50 | persistent_workers: bool = False, 51 | gpu: bool = False, 52 | train_proportion: float = 1.0, 53 | down_sampling: int = 1): 54 | 55 | super().__init__() 56 | 57 | if x_features is None or y_features is None: 58 | assert all_features_num is not None 59 | 60 | self.dataset_path = dataset_path 61 | self.input_len = input_len 62 | self.output_len = output_len 63 | self.x_features = x_features 64 | self.y_features = y_features 65 | self.all_features_num = all_features_num 66 | self.preprocessor = preprocessor() 67 | self.batch_size = batch_size 68 | self.pin_memory = pin_memory 69 | self.num_workers = num_workers 70 | self.persistent_workers = persistent_workers 71 | self.gpu = gpu 72 | self.train_proportion = train_proportion 73 | self.down_sampling = down_sampling 74 | 75 | if self.gpu: 76 | self.pin_memory = False 77 | self.num_workers = 0 78 | self.persistent_workers = False 79 | 80 | if self.x_features is not None: 81 | self.x_features_num = len(self.x_features) 82 | else: 83 | self.x_features_num = self.all_features_num 84 | 85 | if self.y_features is not None: 86 | self.y_features_num = len(self.y_features) 87 | else: 88 | self.y_features_num = self.all_features_num 89 | 90 | self.is_setup = False 91 | 92 | def __read_data__(self) -> pd.DataFrame: 93 | raise NotImplementedError 94 | 95 | def __split_data__(self, all_data: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: 96 | raise NotImplementedError 97 | 98 | def prepare_data(self) -> None: 99 | 100 | # read data in DataFrame format 101 | all_data = self.__read_data__() 102 | 103 | train_data, val_data, test_data = self.__split_data__(all_data) 104 | 105 | # downsampling 106 | train_data = train_data.iloc[::self.down_sampling] 107 | val_data = val_data.iloc[::self.down_sampling] 108 | test_data = test_data.iloc[::self.down_sampling] 109 | 110 | # convert data to a numpy array with the sklearn preprocessor 111 | train_data = self.preprocessor.fit_transform(train_data) 112 | val_data = self.preprocessor.transform(val_data) 113 | test_data = self.preprocessor.transform(test_data) 114 | 115 | # convert data to float32 tensor 116 | train_data = torch.from_numpy(train_data).float() 117 | val_data = torch.from_numpy(val_data).float() 118 | test_data = torch.from_numpy(test_data).float() 119 | 120 | train_len = train_data.shape[0] 121 | train_data = train_data[:int(self.train_proportion * train_len), :] 122 | 123 | if self.gpu: 124 | train_data = train_data.cuda() 125 | val_data = val_data.cuda() 126 | test_data = test_data.cuda() 127 | 128 | if self.x_features is None: 129 | self.train_x = train_data 130 | self.val_x = val_data 131 | self.test_x = test_data 132 | else: 133 | self.train_x = train_data[:, self.x_features] 134 | self.val_x = val_data[:, self.x_features] 135 | self.test_x = test_data[:, self.x_features] 136 | 137 | if self.y_features is None: 138 | self.train_y = train_data 139 | self.val_y = val_data 140 | self.test_y = test_data 141 | else: 142 | self.train_y = train_data[:, self.y_features] 143 | self.val_y = val_data[:, self.y_features] 144 | self.test_y = test_data[:, self.y_features] 145 | 146 | def setup(self, stage: str) -> None: 147 |
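# Lightning may call setup() once per stage ('fit', 'validate', 'test'); the is_setup flag below guards against rebuilding the window datasets.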
148 | if not self.is_setup: 149 | self.train_dataset = TimeSeriesDataset( 150 | self.train_x, 151 | self.train_y, 152 | self.input_len, 153 | self.output_len, 154 | ) 155 | self.val_dataset = TimeSeriesDataset( 156 | self.val_x, 157 | self.val_y, 158 | self.input_len, 159 | self.output_len, 160 | ) 161 | self.test_dataset = TimeSeriesDataset( 162 | self.test_x, 163 | self.test_y, 164 | self.input_len, 165 | self.output_len, 166 | ) 167 | self.is_setup = True 168 | 169 | def train_dataloader(self) -> torch.utils.data.DataLoader: 170 | return torch.utils.data.DataLoader( 171 | self.train_dataset, 172 | batch_size=self.batch_size, 173 | shuffle=True, 174 | pin_memory=self.pin_memory, 175 | num_workers=self.num_workers, 176 | persistent_workers=self.persistent_workers, 177 | ) 178 | 179 | def val_dataloader(self) -> torch.utils.data.DataLoader: 180 | return torch.utils.data.DataLoader( 181 | self.val_dataset, 182 | batch_size=self.batch_size, 183 | shuffle=False, 184 | pin_memory=self.pin_memory, 185 | num_workers=self.num_workers, 186 | persistent_workers=self.persistent_workers, 187 | ) 188 | 189 | def test_dataloader(self) -> torch.utils.data.DataLoader: 190 | return torch.utils.data.DataLoader( 191 | self.test_dataset, 192 | batch_size=self.batch_size, 193 | shuffle=False, 194 | pin_memory=self.pin_memory, 195 | num_workers=self.num_workers, 196 | persistent_workers=self.persistent_workers, 197 | ) 198 | -------------------------------------------------------------------------------- /timeprophet/models/DiPE.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Literal 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | from torch import Tensor, nn 8 | 9 | 10 | class Identity(nn.Module): 11 | 12 | def __init__(self, *args: Any, **kwargs: Any) -> None: 13 | super().__init__() 14 | 15 | def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor: 16 | return x 17 | 18 | 19 | class FFTExpandBigConv1d(nn.Module): 20 | # Designed for small inputs with large convolution kernels 21 | # Input: N, 1, l_in 22 | # Output: N, num_experts, l_out 23 | def __init__( 24 | self, 25 | num_experts: int, 26 | input_len: int, 27 | output_len: int, 28 | ): 29 | super().__init__() 30 | self.num_experts = num_experts 31 | self.input_len = input_len 32 | self.output_len = output_len 33 | 34 | # pad x by 1% of the total length on each end; this helps empirically, for reasons that remain unclear
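# e.g. with input_len=720, output_len=96 (the ETT configs): pad_len = floor((720 + 96 - 1) / 100) = 8, time_len = 815 + 2 * 8 = 831, freq_len = 831 // 2 + 1 = 416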
35 | self.pad_len = math.floor((self.input_len + self.output_len - 1) / 100) 36 | self.pad_len = max(self.pad_len, 1) 37 | 38 | self.time_len = self.input_len + self.output_len - 1 + 2 * self.pad_len 39 | self.freq_len = self.time_len // 2 + 1 40 | 41 | # Initialized as an averaging filter 42 | self.weight = nn.Parameter( 43 | torch.zeros((1, self.num_experts, 1, self.freq_len), 44 | dtype=torch.cfloat)) 45 | self.weight.data[..., 0] = 1 46 | self.bias = nn.Parameter( 47 | torch.zeros(1, 48 | self.num_experts, 49 | 1, 50 | self.freq_len, 51 | dtype=torch.cfloat)) 52 | 53 | def forward(self, x: torch.Tensor, rank_experts=None): 54 | if rank_experts is not None: 55 | weight = self.weight * rank_experts # 1, num_experts, c, freq_len 56 | weight = weight.sum(dim=1, keepdim=True) # 1, 1, c, freq_len 57 | 58 | bias = self.bias * rank_experts 59 | bias = bias.sum(dim=1, keepdim=True) 60 | else: 61 | weight = self.weight 62 | bias = self.bias 63 | 64 | # input: N, 1, l_in 65 | # depth-wise convolution 66 | # pad x to match kernel 67 | 68 | x = F.pad(x, [self.pad_len, self.output_len - 1 + self.pad_len]) 69 | 70 | # calculate FFT 71 | x = torch.fft.rfft(x) 72 | # weight = torch.fft.rfft(weight) 73 | 74 | # frequency-domain multiplication 75 | x = x * weight 76 | 77 | # bias 78 | x = x + bias 79 | 80 | # inverse FFT 81 | x = torch.fft.irfft(x, n=self.time_len) 82 | 83 | x = x[..., -self.output_len - self.pad_len:-self.pad_len] 84 | 85 | # output: N, num_experts, l_out if rank_experts is None else N, 1, l_out 86 | return x 87 | 88 | 89 | class StaticTimeWeight(nn.Module): 90 | 91 | def __init__(self, input_len, num_experts): 92 | super().__init__() 93 | self.input_len = input_len 94 | self.num_experts = num_experts 95 | self.weight = nn.Parameter( 96 | torch.ones(1, self.num_experts, 1, self.input_len)) 97 | 98 | def forward(self, x, rank_experts=None): 99 | # x: N, 1, c, input_len 100 | # if rank_experts provided, output is N, 1, c, input_len 101 | # if not, output is N, num_experts, c, input_len 102 | 103 | if rank_experts is not None: 104 | weight = self.weight * rank_experts # 1, num_experts, c, input_len 105 | weight = weight.sum(dim=1, keepdim=True) # 1, 1, c, input_len 106 | else: 107 | weight = self.weight 108 | x = x * weight 109 | return x 110 | 111 | 112 | class StaticFreqWeight(nn.Module): 113 | # we do not use a window function since it is a linear operation 114 | 115 | def __init__(self, input_len, num_experts): 116 | super().__init__() 117 | self.input_len = input_len 118 | self.num_experts = num_experts 119 | self.weight = nn.Parameter( 120 | torch.ones(1, self.num_experts, 1, self.input_len // 2 + 1)) 121 | 122 | def get_weight_channel(self, rank_experts): 123 | 124 | if rank_experts is not None: 125 | weight = self.weight * rank_experts # 1, num_experts, c, input_len//2+1 126 | weight = weight.sum(dim=1, keepdim=True) # 1, 1, c, input_len//2+1 127 | else: 128 | weight = self.weight 129 | return weight 130 | 131 | def forward(self, x, rank_experts=None, windowing=False): 132 | # x: N, 1, c, input_len 133 | # if rank_experts provided, output is N, 1, c, input_len 134 | # if not, output is N, num_experts, c, input_len 135 | 136 | weight = self.get_weight_channel(rank_experts) 137 | 138 | # x = F.pad(x, [self.input_len // 2, self.input_len // 2]) 139 | if windowing: 140 | window = torch.hamming_window(self.input_len, 141 | dtype=x.dtype, 142 | device=x.device) 143 | x = x * window 144 | x = torch.fft.rfft(x) 145 | x = x * weight 146 | x = torch.fft.irfft(x, n=self.input_len) 147 |
if windowing: 148 | x = x / window 149 | # x = x[:, :, :, self.input_len // 2:-self.input_len // 2] 150 | 151 | return x 152 | 153 | 154 | class DiPE(nn.Module): 155 | 156 | def __init__( 157 | self, 158 | input_len: int, 159 | output_len: int, 160 | input_features: int, 161 | output_features: int, 162 | individual_f: bool = False, 163 | individual_t: bool = False, 164 | individual_c: bool = False, 165 | num_experts: int = 1, 166 | use_revin: bool = True, 167 | use_time_w: bool = True, 168 | use_freq_w: bool = True, 169 | loss_alpha: float = 0., 170 | t_loss: Literal['mse', 'mae'] = 'mse', 171 | ): 172 | super().__init__() 173 | self.input_len = input_len 174 | self.output_len = output_len 175 | self.num_features = input_features 176 | self.individual_f = individual_f 177 | self.individual_t = individual_t 178 | self.individual_c = individual_c 179 | self.num_experts = num_experts 180 | assert input_features == output_features 181 | 182 | self.use_revin = use_revin 183 | self.use_time_w = use_time_w 184 | self.use_freq_w = use_freq_w 185 | self.loss_alpha = loss_alpha 186 | self.t_loss = t_loss 187 | 188 | self.example_input_array = torch.Tensor(32, input_len, input_features) 189 | 190 | if self.num_experts > 1: 191 | self.route = nn.Parameter( 192 | torch.randn(1, num_experts, self.num_features, 1)) 193 | # NaN is a placeholder: the router temperature must be assigned a real value before the softmax is evaluated 194 | self.temperature = float('nan') 195 | self.router_softmax = nn.Softmax(dim=1) 196 | # self.static_route = torch.eye(self.num_experts).unsqueeze(0).unsqueeze(-1) 197 | self.static_route = torch.eye( 198 | self.num_features).unsqueeze(0).unsqueeze(-1) 199 | 200 | if self.use_time_w: 201 | if self.individual_t: 202 | self.time_w = StaticTimeWeight(self.input_len, 203 | self.num_features) 204 | else: 205 | self.time_w = StaticTimeWeight(self.input_len, self.num_experts) 206 | else: 207 | self.time_w = Identity() 208 | 209 | if self.use_freq_w: 210 | if self.individual_f: 211 | self.freq_w = StaticFreqWeight(self.input_len, 212 | self.num_features) 213 | else: 214 | self.freq_w = StaticFreqWeight(self.input_len, self.num_experts) 215 | else: 216 | self.freq_w = Identity() 217 | 218 | if self.individual_c: 219 | self.expert = FFTExpandBigConv1d(self.num_features, self.input_len, 220 | self.output_len) 221 | else: 222 | self.expert = FFTExpandBigConv1d(self.num_experts, self.input_len, 223 | self.output_len) 224 | 225 | self.dropout = nn.Dropout(0.1) 226 | 227 | def forward(self, x): 228 | batch_size = x.shape[0] 229 | 230 | x = rearrange(x, 'n l c -> n 1 c l') 231 | 232 | if self.use_revin: 233 | 234 | x_mean = x.mean(dim=-1, keepdim=True).detach() 235 | x_std = x.std(dim=-1, keepdim=True).detach().clamp(min=1e-7) 236 | x = (x - x_mean) / x_std 237 | 238 | if self.num_experts > 1: 239 | 240 | rank_experts = self.router_softmax(self.route / 241 | self.temperature) # 1, h, c, 1 242 | 243 | else: 244 | rank_experts = None 245 | 246 | if self.individual_f: 247 | x = self.freq_w(x, self.static_route.to(x.device)) 248 | else: 249 | x = self.freq_w(x, rank_experts) 250 | x = self.dropout(x) 251 | 252 | if self.individual_t: 253 | x = self.time_w(x, self.static_route.to(x.device)) 254 | else: 255 | x = self.time_w(x, rank_experts) 256 | 257 | if self.individual_c: 258 | x = self.expert(x, self.static_route.to(x.device)) 259 | else: 260 | x = self.expert(x, rank_experts) 261 | 262 | if self.use_revin: 263 | x = x * x_std 264 | x = x + x_mean 265 | 266 | x = rearrange(x, 'n 1 c l -> n l c') 267 | 268 | return x 269 | 270 | def loss(self, y, y_hat): 271 | y = rearrange(y,
'n l c -> n c l') 272 | y_hat = rearrange(y_hat, 'n l c -> n c l') 273 | 274 | if self.t_loss == 'mse': 275 | time_loss = F.mse_loss(y, y_hat) 276 | else: 277 | time_loss = F.l1_loss(y, y_hat) 278 | 279 | if self.use_freq_w: 280 | 281 | if self.num_experts > 1: 282 | rank_experts = self.router_softmax( 283 | self.route / self.temperature) # 1, h, c, 1 284 | else: 285 | rank_experts = None 286 | 287 | if self.individual_f: 288 | rank_experts = self.static_route.to(y.device) 289 | 290 | freq_w = self.freq_w.get_weight_channel(rank_experts) 291 | freq_w = freq_w.detach() 292 | freq_w = freq_w / freq_w.mean(dim=-1, keepdim=True) 293 | 294 | if freq_w.shape[-1] != y.shape[-1] // 2 + 1: 295 | with torch.no_grad(): 296 | freq_w = torch.fft.irfft(freq_w, n=y.shape[-1]) 297 | freq_w = torch.fft.rfft(freq_w) 298 | 299 | else: 300 | freq_w = 1 301 | 302 | fft_y = torch.fft.rfft(y, norm='ortho') 303 | fft_y_hat = torch.fft.rfft(y_hat, norm='ortho') 304 | 305 | freq_loss = F.l1_loss(fft_y * freq_w, fft_y_hat * freq_w) 306 | 307 | return (1 - self.loss_alpha) * time_loss + self.loss_alpha * freq_loss 308 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ts 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_kmp_llvm 9 | - absl-py=2.1.0=pyhd8ed1ab_0 10 | - alsa-lib=1.2.12=h4ab18f5_0 11 | - aom=3.6.1=h59595ed_0 12 | - argcomplete=3.5.1=pyhd8ed1ab_0 13 | - asttokens=2.4.1=pyhd8ed1ab_0 14 | - attrs=24.2.0=pyh71513ae_0 15 | - blas=2.116=mkl 16 | - blas-devel=3.9.0=16_linux64_mkl 17 | - brotli=1.1.0=hb9d3cd8_2 18 | - brotli-bin=1.1.0=hb9d3cd8_2 19 | - brotli-python=1.1.0=py311hfdbb021_2 20 | - bzip2=1.0.8=h4bc722e_7 21 | - c-ares=1.34.1=heb4867d_0 22 | - ca-certificates=2024.8.30=hbcca054_0 23 | - cairo=1.18.0=hebfffa5_3 24 | - certifi=2024.8.30=pyhd8ed1ab_0 25 | - cffi=1.17.1=py311hf29c0ef_0 26 | - charset-normalizer=3.4.0=pyhd8ed1ab_0 27 | - colorama=0.4.6=pyhd8ed1ab_0 28 | - comm=0.2.2=pyhd8ed1ab_0 29 | - contourpy=1.3.0=py311hd18a35c_2 30 | - cpython=3.11.10=py311hd8ed1ab_2 31 | - cuda-cudart=12.1.105=0 32 | - cuda-cupti=12.1.105=0 33 | - cuda-libraries=12.1.0=0 34 | - cuda-nvrtc=12.1.105=0 35 | - cuda-nvtx=12.1.105=0 36 | - cuda-opencl=12.6.77=0 37 | - cuda-runtime=12.1.0=0 38 | - cuda-version=12.6=3 39 | - cycler=0.12.1=pyhd8ed1ab_0 40 | - cyrus-sasl=2.1.27=h54b06d7_7 41 | - dataclasses=0.8=pyhc8e2a94_3 42 | - dbus=1.13.6=h5008d03_3 43 | - debugpy=1.8.6=py311hfdbb021_0 44 | - decorator=5.1.1=pyhd8ed1ab_0 45 | - docstring_parser=0.16=pyhd8ed1ab_0 46 | - double-conversion=3.3.0=h59595ed_0 47 | - einops=0.8.0=pyhd8ed1ab_0 48 | - exceptiongroup=1.2.2=pyhd8ed1ab_0 49 | - executing=2.1.0=pyhd8ed1ab_0 50 | - expat=2.6.3=h5888daf_0 51 | - ffmpeg=4.4.2=gpl_hdf48244_113 52 | - filelock=3.16.1=pyhd8ed1ab_0 53 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 54 | - font-ttf-inconsolata=3.000=h77eed37_0 55 | - font-ttf-source-code-pro=2.038=h77eed37_0 56 | - font-ttf-ubuntu=0.83=h77eed37_3 57 | - fontconfig=2.14.2=h14ed4e7_0 58 | - fonts-conda-ecosystem=1=0 59 | - fonts-conda-forge=1=0 60 | - fonttools=4.54.1=py311h9ecbd09_0 61 | - freetype=2.12.1=h267a509_2 62 | - fsspec=2024.9.0=pyhff2d567_0 63 | - gettext=0.22.5=he02047a_3 64 | - gettext-tools=0.22.5=he02047a_3 65 | - gh=2.59.0=h76a2195_0 66 | - gmp=6.3.0=hac33072_2 67 | - gmpy2=2.1.5=py311h0f6cedb_2 68 | - gnutls=3.7.9=hb077bed_0 69 | - 
graphite2=1.3.13=h59595ed_1003 70 | - grpcio=1.65.5=py311h9c9ff8c_0 71 | - h2=4.1.0=pyhd8ed1ab_0 72 | - harfbuzz=9.0.0=hda332d3_1 73 | - hpack=4.0.0=pyh9f0ad1d_0 74 | - hyperframe=6.0.1=pyhd8ed1ab_0 75 | - icu=75.1=he02047a_0 76 | - idna=3.10=pyhd8ed1ab_0 77 | - importlib-metadata=8.5.0=pyha770c72_0 78 | - importlib-resources=6.4.5=pyhd8ed1ab_0 79 | - importlib_resources=6.4.5=pyhd8ed1ab_0 80 | - ipykernel=6.29.5=pyh3099207_0 81 | - ipython=8.28.0=pyh707e725_0 82 | - jedi=0.19.1=pyhd8ed1ab_0 83 | - jinja2=3.1.4=pyhd8ed1ab_0 84 | - joblib=1.4.2=pyhd8ed1ab_0 85 | - jsonargparse=4.33.2=pyhd8ed1ab_0 86 | - jsonnet=0.20.0=py311hb755f60_1 87 | - jsonschema=4.23.0=pyhd8ed1ab_0 88 | - jsonschema-specifications=2024.10.1=pyhd8ed1ab_0 89 | - jupyter_client=8.6.3=pyhd8ed1ab_0 90 | - jupyter_core=5.7.2=pyh31011fe_1 91 | - keyutils=1.6.1=h166bdaf_0 92 | - kiwisolver=1.4.7=py311hd18a35c_0 93 | - krb5=1.21.3=h659f571_0 94 | - lame=3.100=h166bdaf_1003 95 | - lcms2=2.16=hb7c19ff_0 96 | - ld_impl_linux-64=2.43=h712a8e2_1 97 | - lerc=4.0.0=h27087fc_0 98 | - libabseil=20240722.0=cxx17_h5888daf_1 99 | - libasprintf=0.22.5=he8f35ee_3 100 | - libasprintf-devel=0.22.5=he8f35ee_3 101 | - libblas=3.9.0=16_linux64_mkl 102 | - libbrotlicommon=1.1.0=hb9d3cd8_2 103 | - libbrotlidec=1.1.0=hb9d3cd8_2 104 | - libbrotlienc=1.1.0=hb9d3cd8_2 105 | - libcblas=3.9.0=16_linux64_mkl 106 | - libclang-cpp19.1=19.1.1=default_hb5137d0_0 107 | - libclang13=19.1.1=default_h9c6a7e4_0 108 | - libcublas=12.1.0.26=0 109 | - libcufft=11.0.2.4=0 110 | - libcufile=1.11.1.6=0 111 | - libcups=2.3.3=h4637d8d_4 112 | - libcurand=10.3.7.77=0 113 | - libcusolver=11.4.4.55=0 114 | - libcusparse=12.0.2.55=0 115 | - libdeflate=1.22=hb9d3cd8_0 116 | - libdrm=2.4.123=hb9d3cd8_0 117 | - libedit=3.1.20191231=he28a2e2_2 118 | - libegl=1.7.0=ha4b6fd6_1 119 | - libexpat=2.6.3=h5888daf_0 120 | - libffi=3.4.2=h7f98852_5 121 | - libgcc=14.1.0=h77fa898_1 122 | - libgcc-ng=14.1.0=h69a702a_1 123 | - libgettextpo=0.22.5=he02047a_3 124 | - libgettextpo-devel=0.22.5=he02047a_3 125 | - libgfortran=14.1.0=h69a702a_1 126 | - libgfortran-ng=14.1.0=h69a702a_1 127 | - libgfortran5=14.1.0=hc5f4f2c_1 128 | - libgl=1.7.0=ha4b6fd6_1 129 | - libglib=2.82.1=h2ff4ddf_0 130 | - libglvnd=1.7.0=ha4b6fd6_1 131 | - libglx=1.7.0=ha4b6fd6_1 132 | - libgomp=14.1.0=h77fa898_1 133 | - libgrpc=1.65.5=hf5c653b_0 134 | - libhwloc=2.11.1=default_hecaa2ac_1000 135 | - libiconv=1.17=hd590300_2 136 | - libidn2=2.3.7=hd590300_0 137 | - libjpeg-turbo=3.0.0=hd590300_1 138 | - liblapack=3.9.0=16_linux64_mkl 139 | - liblapacke=3.9.0=16_linux64_mkl 140 | - libllvm19=19.1.1=ha7bfdaf_0 141 | - libnpp=12.0.2.50=0 142 | - libnsl=2.0.1=hd590300_0 143 | - libntlm=1.4=h7f98852_1002 144 | - libnvjitlink=12.1.105=0 145 | - libnvjpeg=12.1.1.14=0 146 | - libopengl=1.7.0=ha4b6fd6_1 147 | - libpciaccess=0.18=hd590300_0 148 | - libpng=1.6.44=hadc24fc_0 149 | - libpq=17.0=h04577a9_2 150 | - libprotobuf=5.27.5=h5b01275_2 151 | - libre2-11=2023.11.01=hbbce691_0 152 | - libsodium=1.0.20=h4ab18f5_0 153 | - libsqlite=3.46.1=hadc24fc_0 154 | - libstdcxx=14.1.0=hc0a3c3a_1 155 | - libstdcxx-ng=14.1.0=h4852527_1 156 | - libtasn1=4.19.0=h166bdaf_0 157 | - libtiff=4.7.0=he137b08_1 158 | - libunistring=0.9.10=h7f98852_0 159 | - libuuid=2.38.1=h0b41bf4_0 160 | - libva=2.22.0=h8a09558_1 161 | - libvpx=1.13.1=h59595ed_0 162 | - libwebp-base=1.4.0=hd590300_0 163 | - libxcb=1.17.0=h8a09558_0 164 | - libxcrypt=4.4.36=hd590300_1 165 | - libxkbcommon=1.7.0=h2c5496b_1 166 | - libxml2=2.12.7=he7c6b58_4 167 | - libxslt=1.1.39=h76b75d6_0 
168 | - libzlib=1.3.1=hb9d3cd8_2 169 | - lightning=2.4.0=pyhd8ed1ab_0 170 | - lightning-utilities=0.11.7=pyhd8ed1ab_0 171 | - llvm-openmp=15.0.7=h0cdce71_0 172 | - markdown=3.6=pyhd8ed1ab_0 173 | - markdown-it-py=3.0.0=pyhd8ed1ab_0 174 | - markupsafe=3.0.1=py311h2dc5d0c_1 175 | - matplotlib=3.9.2=py311h38be061_1 176 | - matplotlib-base=3.9.2=py311h2b939e6_1 177 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 178 | - mdurl=0.1.2=pyhd8ed1ab_0 179 | - mkl=2022.1.0=h84fe81f_915 180 | - mkl-devel=2022.1.0=ha770c72_916 181 | - mkl-include=2022.1.0=h84fe81f_915 182 | - mpc=1.3.1=h24ddda3_1 183 | - mpfr=4.2.1=h90cbb55_3 184 | - mpmath=1.3.0=pyhd8ed1ab_0 185 | - munkres=1.1.4=pyh9f0ad1d_0 186 | - mysql-common=9.0.1=h266115a_1 187 | - mysql-libs=9.0.1=he0572af_1 188 | - nbformat=5.10.4=pyhd8ed1ab_0 189 | - ncurses=6.5=he02047a_1 190 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 191 | - nettle=3.9.1=h7ab15ed_0 192 | - networkx=3.3=pyhd8ed1ab_1 193 | - numpy=2.1.2=py311h71ddf71_0 194 | - openh264=2.3.1=hcb278e6_2 195 | - openjpeg=2.5.2=h488ebb8_0 196 | - openldap=2.6.8=hedd0468_0 197 | - openssl=3.4.0=hb9d3cd8_0 198 | - p11-kit=0.24.1=hc5aa10d_0 199 | - packaging=24.1=pyhd8ed1ab_0 200 | - pandas=2.2.3=py311h7db5c69_1 201 | - parso=0.8.4=pyhd8ed1ab_0 202 | - patsy=1.0.1=pyhff2d567_0 203 | - pcre2=10.44=hba22ea6_2 204 | - pexpect=4.9.0=pyhd8ed1ab_0 205 | - pickleshare=0.7.5=py_1003 206 | - pillow=10.4.0=py311h4aec55e_1 207 | - pip=24.2=pyh8b19718_1 208 | - pixman=0.43.2=h59595ed_0 209 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1 210 | - platformdirs=4.3.6=pyhd8ed1ab_0 211 | - plotly=5.24.1=pyhd8ed1ab_0 212 | - prompt-toolkit=3.0.48=pyha770c72_0 213 | - protobuf=5.27.5=py311hfdbb021_0 214 | - psutil=6.0.0=py311h9ecbd09_1 215 | - pthread-stubs=0.4=hb9d3cd8_1002 216 | - ptyprocess=0.7.0=pyhd3deb0d_0 217 | - pure_eval=0.2.3=pyhd8ed1ab_0 218 | - pycparser=2.22=pyhd8ed1ab_0 219 | - pygments=2.18.0=pyhd8ed1ab_0 220 | - pyparsing=3.1.4=pyhd8ed1ab_0 221 | - pyside6=6.7.3=py311h9053184_1 222 | - pysocks=1.7.1=pyha2e5f31_6 223 | - python=3.11.10=hc5c86c4_2_cpython 224 | - python-dateutil=2.9.0=pyhd8ed1ab_0 225 | - python-fastjsonschema=2.20.0=pyhd8ed1ab_0 226 | - python-tzdata=2024.2=pyhd8ed1ab_0 227 | - python_abi=3.11=5_cp311 228 | - pytorch=2.4.1=py3.11_cuda12.1_cudnn9.1.0_0 229 | - pytorch-cuda=12.1=ha16c6d3_5 230 | - pytorch-lightning=2.4.0=pyhd8ed1ab_0 231 | - pytorch-mutex=1.0=cuda 232 | - pytz=2024.1=pyhd8ed1ab_0 233 | - pyyaml=6.0.2=py311h9ecbd09_1 234 | - pyzmq=26.2.0=py311h7deb3e3_2 235 | - qhull=2020.2=h434a139_5 236 | - qt6-main=6.7.3=h6e8976b_1 237 | - re2=2023.11.01=h77b4e00_0 238 | - readline=8.2=h8228510_1 239 | - referencing=0.35.1=pyhd8ed1ab_0 240 | - requests=2.32.3=pyhd8ed1ab_0 241 | - rich=13.9.2=pyhd8ed1ab_0 242 | - rpds-py=0.20.0=py311h9e33e62_1 243 | - scikit-learn=1.5.2=py311h57cc02b_1 244 | - scipy=1.14.1=py311he1f765f_0 245 | - seaborn=0.13.2=hd8ed1ab_2 246 | - seaborn-base=0.13.2=pyhd8ed1ab_2 247 | - setuptools=75.1.0=pyhd8ed1ab_0 248 | - six=1.16.0=pyh6c4a22f_0 249 | - stack_data=0.6.2=pyhd8ed1ab_0 250 | - statsmodels=0.14.4=py311h9f3472d_0 251 | - svt-av1=1.4.1=hcb278e6_0 252 | - sympy=1.13.3=pyh2585a3b_104 253 | - tbb=2021.13.0=h84d6215_0 254 | - tenacity=9.0.0=pyhd8ed1ab_0 255 | - tensorboard=2.18.0=pyhd8ed1ab_0 256 | - tensorboard-data-server=0.7.0=py311h63ff55d_1 257 | - threadpoolctl=3.5.0=pyhc1e730c_0 258 | - tk=8.6.13=noxft_h4845f30_101 259 | - tomli=2.0.2=pyhd8ed1ab_0 260 | - torchaudio=2.4.1=py311_cu121 261 | - torchmetrics=1.4.2=pyhd8ed1ab_0 262 | - torchtriton=3.0.0=py311 263 | - 
torchvision=0.19.1=py311_cu121 264 | - tornado=6.4.1=py311h9ecbd09_1 265 | - tqdm=4.66.5=pyhd8ed1ab_0 266 | - traitlets=5.14.3=pyhd8ed1ab_0 267 | - typeshed-client=2.4.0=pyhd8ed1ab_0 268 | - typing-extensions=4.12.2=hd8ed1ab_0 269 | - typing_extensions=4.12.2=pyha770c72_0 270 | - tzdata=2024b=hc8b5060_0 271 | - urllib3=2.2.3=pyhd8ed1ab_0 272 | - validators=0.34.0=pyhd8ed1ab_0 273 | - wayland=1.23.1=h3e06ad9_0 274 | - wayland-protocols=1.37=hd8ed1ab_0 275 | - wcwidth=0.2.13=pyhd8ed1ab_0 276 | - werkzeug=3.0.4=pyhd8ed1ab_0 277 | - wheel=0.44.0=pyhd8ed1ab_0 278 | - x264=1!164.3095=h166bdaf_2 279 | - x265=3.5=h924138e_3 280 | - xcb-util=0.4.1=hb711507_2 281 | - xcb-util-cursor=0.1.5=hb9d3cd8_0 282 | - xcb-util-image=0.4.0=hb711507_2 283 | - xcb-util-keysyms=0.4.1=hb711507_0 284 | - xcb-util-renderutil=0.3.10=hb711507_0 285 | - xcb-util-wm=0.4.2=hb711507_0 286 | - xkeyboard-config=2.43=hb9d3cd8_0 287 | - xorg-libice=1.1.1=hb9d3cd8_1 288 | - xorg-libsm=1.2.4=he73a12e_1 289 | - xorg-libx11=1.8.10=h4f16b4b_0 290 | - xorg-libxau=1.0.11=hb9d3cd8_1 291 | - xorg-libxcomposite=0.4.6=hb9d3cd8_2 292 | - xorg-libxcursor=1.2.2=hb9d3cd8_0 293 | - xorg-libxdamage=1.1.6=hb9d3cd8_0 294 | - xorg-libxdmcp=1.1.5=hb9d3cd8_0 295 | - xorg-libxext=1.3.6=hb9d3cd8_0 296 | - xorg-libxfixes=6.0.1=hb9d3cd8_0 297 | - xorg-libxi=1.8.2=hb9d3cd8_0 298 | - xorg-libxrandr=1.5.4=hb9d3cd8_0 299 | - xorg-libxrender=0.9.11=hb9d3cd8_1 300 | - xorg-libxtst=1.2.5=hb9d3cd8_3 301 | - xorg-libxxf86vm=1.1.5=hb9d3cd8_3 302 | - xorg-xorgproto=2024.1=hb9d3cd8_1 303 | - xz=5.2.6=h166bdaf_0 304 | - yaml=0.2.5=h7f98852_2 305 | - yapf=0.40.1=pyhd8ed1ab_0 306 | - zeromq=4.3.5=h3b0a872_6 307 | - zipp=3.20.2=pyhd8ed1ab_0 308 | - zlib=1.3.1=hb9d3cd8_2 309 | - zstandard=0.23.0=py311hbc35293_1 310 | - zstd=1.5.6=ha6fb4c9_0 311 | - pip: 312 | - torch-tb-profiler==0.4.3 313 | prefix: /root/miniforge3/envs/ts 314 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MAIN] 9 | 10 | # Files or directories to be skipped. They should be base names, not paths. 11 | ignore=third_party 12 | 13 | # Files or directories matching the regex patterns are skipped. The regex 14 | # matches against base names, not paths. 15 | ignore-patterns= 16 | 17 | # Pickle collected data for later comparisons. 18 | persistent=no 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | # Use multiple processes to speed up Pylint. 25 | jobs=4 26 | 27 | # Allow loading of arbitrary C extensions. Extensions are imported into the 28 | # active Python interpreter and may run arbitrary code. 29 | unsafe-load-any-extension=no 30 | 31 | 32 | [MESSAGES CONTROL] 33 | 34 | # Only show warnings with the listed confidence levels. Leave empty to show 35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 36 | confidence= 37 | 38 | # Enable the message, report, category or checker with the given id(s). 
You can 39 | # either give multiple identifier separated by comma (,) or put this option 40 | # multiple time (only on the command line, not in the configuration file where 41 | # it should appear only once). See also the "--disable" option for examples. 42 | #enable= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once).You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=R, 54 | abstract-method, 55 | apply-builtin, 56 | arguments-differ, 57 | attribute-defined-outside-init, 58 | backtick, 59 | bad-option-value, 60 | basestring-builtin, 61 | buffer-builtin, 62 | c-extension-no-member, 63 | consider-using-enumerate, 64 | cmp-builtin, 65 | cmp-method, 66 | coerce-builtin, 67 | coerce-method, 68 | delslice-method, 69 | div-method, 70 | eq-without-hash, 71 | execfile-builtin, 72 | file-builtin, 73 | filter-builtin-not-iterating, 74 | fixme, 75 | getslice-method, 76 | global-statement, 77 | hex-method, 78 | idiv-method, 79 | implicit-str-concat, 80 | import-error, 81 | import-self, 82 | import-star-module-level, 83 | input-builtin, 84 | intern-builtin, 85 | invalid-str-codec, 86 | locally-disabled, 87 | long-builtin, 88 | long-suffix, 89 | map-builtin-not-iterating, 90 | misplaced-comparison-constant, 91 | missing-function-docstring, 92 | metaclass-assignment, 93 | next-method-called, 94 | next-method-defined, 95 | no-absolute-import, 96 | no-init, # added 97 | no-member, 98 | no-name-in-module, 99 | no-self-use, 100 | nonzero-method, 101 | not-callable, 102 | oct-method, 103 | old-division, 104 | old-ne-operator, 105 | old-octal-literal, 106 | old-raise-syntax, 107 | parameter-unpacking, 108 | print-statement, 109 | raising-string, 110 | range-builtin-not-iterating, 111 | raw_input-builtin, 112 | rdiv-method, 113 | reduce-builtin, 114 | relative-import, 115 | reload-builtin, 116 | round-builtin, 117 | setslice-method, 118 | signature-differs, 119 | standarderror-builtin, 120 | suppressed-message, 121 | sys-max-int, 122 | trailing-newlines, 123 | unichr-builtin, 124 | unicode-builtin, 125 | unnecessary-pass, 126 | unpacking-in-except, 127 | useless-else-on-loop, 128 | useless-suppression, 129 | using-cmp-argument, 130 | wrong-import-order, 131 | xrange-builtin, 132 | zip-builtin-not-iterating, 133 | 134 | 135 | [REPORTS] 136 | 137 | # Set the output format. Available formats are text, parseable, colorized, msvs 138 | # (visual studio) and html. You can also give a reporter class, eg 139 | # mypackage.mymodule.MyReporterClass. 140 | output-format=text 141 | 142 | # Tells whether to display a full report or only the messages 143 | reports=no 144 | 145 | # Python expression which should return a note less than 10 (10 is the highest 146 | # note). You have access to the variables errors warning, statement which 147 | # respectively contain the number of errors / warnings messages and the total 148 | # number of statements analyzed. This is used by the global evaluation report 149 | # (RP0004). 
150 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 151 | 152 | # Template used to display messages. This is a python new-style format string 153 | # used to format the message information. See doc for all details 154 | #msg-template= 155 | 156 | 157 | [BASIC] 158 | 159 | # Good variable names which should always be accepted, separated by a comma 160 | good-names=main,_ 161 | 162 | # Bad variable names which should always be refused, separated by a comma 163 | bad-names= 164 | 165 | # Colon-delimited sets of names that determine each other's naming style when 166 | # the name regexes allow several styles. 167 | name-group= 168 | 169 | # Include a hint for the correct naming format with invalid-name 170 | include-naming-hint=no 171 | 172 | # List of decorators that produce properties, such as abc.abstractproperty. Add 173 | # to this list to register other decorators that produce valid properties. 174 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 175 | 176 | # Regular expression matching correct function names 177 | function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ 178 | 179 | # Regular expression matching correct variable names 180 | variable-rgx=^[a-z][a-z0-9_]*$ 181 | 182 | # Regular expression matching correct constant names 183 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 184 | 185 | # Regular expression matching correct attribute names 186 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 187 | 188 | # Regular expression matching correct argument names 189 | argument-rgx=^[a-z][a-z0-9_]*$ 190 | 191 | # Regular expression matching correct class attribute names 192 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 193 | 194 | # Regular expression matching correct inline iteration names 195 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 196 | 197 | # Regular expression matching correct class names 198 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 199 | 200 | # Regular expression matching correct module names 201 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 202 | 203 | # Regular expression matching correct method names 204 | method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ 205 | 206 | # Regular expression which should only match function or class names that do 207 | # not require a docstring. 208 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 209 | 210 | # Minimum line length for functions/classes that require docstrings, shorter 211 | # ones are exempt. 212 | docstring-min-length=12 213 | 214 | 215 | [TYPECHECK] 216 | 217 | # List of decorators that produce context managers, such as 218 | # contextlib.contextmanager. Add to this list to register other decorators that 219 | # produce valid context managers. 220 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 221 | 222 | # List of module names for which member attributes should not be checked 223 | # (useful for modules/projects where namespaces are manipulated during runtime 224 | # and thus existing member attributes cannot be deduced by static analysis. It 225 | # supports qualified module names, as well as Unix pattern matching. 
226 | ignored-modules= 227 | 228 | # List of class names for which member attributes should not be checked (useful 229 | # for classes with dynamically set attributes). This supports the use of 230 | # qualified names. 231 | ignored-classes=optparse.Values,thread._local,_thread._local 232 | 233 | # List of members which are set dynamically and missed by pylint inference 234 | # system, and so shouldn't trigger E1101 when accessed. Python regular 235 | # expressions are accepted. 236 | generated-members= 237 | 238 | 239 | [FORMAT] 240 | 241 | # Maximum number of characters on a single line. 242 | max-line-length=80 243 | 244 | # TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt 245 | # lines made too long by directives to pytype. 246 | 247 | # Regexp for a line that is allowed to be longer than the limit. 248 | ignore-long-lines=(?x)( 249 | ^\s*(\#\ )??$| 250 | ^\s*(from\s+\S+\s+)?import\s+.+$) 251 | 252 | # Allow the body of an if to be on the same line as the test if there is no 253 | # else. 254 | single-line-if-stmt=yes 255 | 256 | # Maximum number of lines in a module 257 | max-module-lines=99999 258 | 259 | # String used as indentation unit. The internal Google style guide mandates 2 260 | # spaces. Google's externaly-published style guide says 4, consistent with 261 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 262 | # projects (like TensorFlow). 263 | 264 | # NO! 4 SPACES IS GOOD CIVILIZATION! 265 | indent-string=' ' 266 | 267 | # Number of spaces of indent required inside a hanging or continued line. 268 | indent-after-paren=4 269 | 270 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 271 | expected-line-ending-format= 272 | 273 | 274 | [MISCELLANEOUS] 275 | 276 | # List of note tags to take in consideration, separated by a comma. 277 | notes=TODO 278 | 279 | 280 | [STRING] 281 | 282 | # This flag controls whether inconsistent-quotes generates a warning when the 283 | # character used as a quote delimiter is used inconsistently within a module. 284 | check-quote-consistency=yes 285 | 286 | 287 | [VARIABLES] 288 | 289 | # Tells whether we should check for unused import in __init__ files. 290 | init-import=no 291 | 292 | # A regular expression matching the name of dummy variables (i.e. expectedly 293 | # not used). 294 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 295 | 296 | # List of additional names supposed to be defined in builtins. Remember that 297 | # you should avoid to define new builtins when possible. 298 | additional-builtins= 299 | 300 | # List of strings which can identify a callback function by name. A callback 301 | # name must start or end with one of those strings. 302 | callbacks=cb_,_cb 303 | 304 | # List of qualified module names which can have objects that can redefine 305 | # builtins. 306 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 307 | 308 | 309 | [LOGGING] 310 | 311 | # Logging modules to check that the string format arguments are in logging 312 | # function parameter format 313 | logging-modules=logging,absl.logging,tensorflow.io.logging 314 | 315 | 316 | [SIMILARITIES] 317 | 318 | # Minimum lines number of a similarity. 319 | min-similarity-lines=4 320 | 321 | # Ignore comments when computing similarities. 322 | ignore-comments=yes 323 | 324 | # Ignore docstrings when computing similarities. 325 | ignore-docstrings=yes 326 | 327 | # Ignore imports when computing similarities. 
328 | ignore-imports=no 329 | 330 | 331 | [SPELLING] 332 | 333 | # Spelling dictionary name. Available dictionaries: none. To make it working 334 | # install python-enchant package. 335 | spelling-dict= 336 | 337 | # List of comma separated words that should not be checked. 338 | spelling-ignore-words= 339 | 340 | # A path to a file that contains private dictionary; one word per line. 341 | spelling-private-dict-file= 342 | 343 | # Tells whether to store unknown words to indicated private dictionary in 344 | # --spelling-private-dict-file option instead of raising a message. 345 | spelling-store-unknown-words=no 346 | 347 | 348 | [IMPORTS] 349 | 350 | # Deprecated modules which should not be used, separated by a comma 351 | deprecated-modules=regsub, 352 | TERMIOS, 353 | Bastion, 354 | rexec, 355 | sets 356 | 357 | # Create a graph of every (i.e. internal and external) dependencies in the 358 | # given file (report RP0402 must not be disabled) 359 | import-graph= 360 | 361 | # Create a graph of external dependencies in the given file (report RP0402 must 362 | # not be disabled) 363 | ext-import-graph= 364 | 365 | # Create a graph of internal dependencies in the given file (report RP0402 must 366 | # not be disabled) 367 | int-import-graph= 368 | 369 | # Force import order to recognize a module as part of the standard 370 | # compatibility libraries. 371 | known-standard-library= 372 | 373 | # Force import order to recognize a module as part of a third party library. 374 | known-third-party=enchant, absl 375 | 376 | # Analyse import fallback blocks. This can be used to support both Python 2 and 377 | # 3 compatible code, which means that the block might have code that exists 378 | # only in one or another interpreter, leading to false positives when analysed. 379 | analyse-fallback-blocks=no 380 | 381 | 382 | [CLASSES] 383 | 384 | # List of method names used to declare (i.e. assign) instance attributes. 385 | defining-attr-methods=__init__, 386 | __new__, 387 | setUp 388 | 389 | # List of member names, which should be excluded from the protected access 390 | # warning. 391 | exclude-protected=_asdict, 392 | _fields, 393 | _replace, 394 | _source, 395 | _make 396 | 397 | # List of valid names for the first argument in a class method. 398 | valid-classmethod-first-arg=cls, 399 | class_ 400 | 401 | # List of valid names for the first argument in a metaclass class method. 402 | valid-metaclass-classmethod-first-arg=mcs 403 | --------------------------------------------------------------------------------
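As a closing note for readers skimming the model code above, the following minimal smoke test (not part of the repository; constructor arguments taken from the ETT configs and `example_input_array`) sketches the expected input/output shapes of the DiPE module with its default settings:

```python
import torch

from timeprophet.models import DiPE

# Defaults: a single expert, RevIN enabled, time- and frequency-domain
# weighting enabled, so the placeholder router temperature is never used.
model = DiPE(input_len=720, output_len=96, input_features=7, output_features=7)

x = torch.randn(32, 720, 7)  # (batch, input_len, channels)
y = model(x)
print(y.shape)  # torch.Size([32, 96, 7])
```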