├── .gitignore
├── LICENSE
├── README.md
├── Tutorial.md
├── data_provider
├── __init__.py
├── anomaly_detection.yaml
├── data_factory.py
├── data_loader.py
├── fewshot_new_task.yaml
├── imputation.yaml
├── m4.py
├── multi_task.yaml
├── multi_task_pretrain.yaml
├── multitask_zero_shot_new_length.yaml
├── uea.py
└── zeroshot_task.yaml
├── download_data_all.sh
├── exp
├── __init__.py
├── exp_pretrain.py
└── exp_sup.py
├── models
├── UniTS.py
├── UniTS_zeroshot.py
└── __init__.py
├── requirements.txt
├── run.py
├── run_pretrain.py
├── scripts
├── few_shot_anomaly_detection
│ ├── UniTS_finetune_few_shot_anomaly_detection.sh
│ └── UniTS_prompt_tuning_few_shot_anomaly_detection.sh
├── few_shot_imputation
│ ├── UniTS_finetune_few_shot_imputation_mask025.sh
│ ├── UniTS_finetune_few_shot_imputation_mask050.sh
│ ├── UniTS_prompt_tuning_few_shot_imputation_mask025.sh
│ └── UniTS_prompt_tuning_few_shot_imputation_mask050.sh
├── few_shot_newdata
│ ├── UniTS_finetune_few_shot_newdata_pct05.sh
│ ├── UniTS_finetune_few_shot_newdata_pct15.sh
│ ├── UniTS_finetune_few_shot_newdata_pct20.sh
│ ├── UniTS_prompt_tuning_few_shot_newdata_pct05.sh
│ ├── UniTS_prompt_tuning_few_shot_newdata_pct15.sh
│ └── UniTS_prompt_tuning_few_shot_newdata_pct20.sh
├── pretrain_prompt_learning
│ ├── UniTS_pretrain_x128.sh
│ ├── UniTS_pretrain_x32.sh
│ └── UniTS_pretrain_x64.sh
├── supervised_learning
│ └── UniTS_supervised.sh
└── zero_shot
│ ├── UniTS_forecast_new_length_unify.sh
│ └── UniTS_zeroshot_newdata.sh
└── utils
├── __init__.py
├── dataloader.py
├── ddp.py
├── layer_decay.py
├── losses.py
├── m4_summary.py
├── masking.py
├── metrics.py
├── timefeatures.py
└── tools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | /scripts/long_term_forecast/Traffic_script/PatchTST1.sh
131 | /backups/
132 | /result.xlsx
133 | /~$result.xlsx
134 | /Time-Series-Library.zip
135 | /temp.sh
136 |
137 | .idea
138 | /tv_result.xlsx
139 | /test.py
140 | /m4_results/
141 | /test_results/
142 | /PatchTST_results.xlsx
143 | /seq_len_long_term_forecast/
144 | /progress.xlsx
145 | /scripts/short_term_forecast/PatchTST_M4.sh
146 | /run_tv.py
147 |
148 | /scripts/long_term_forecast/ETT_tv_script/
149 | /dataset/
150 | data_factory_all.py
151 | data_loader_all.py
152 | /scripts/short_term_forecast/tv_script/
153 | /exp/exp_short_term_forecasting_tv.py
154 | /exp/exp_long_term_forecasting_tv.py
155 | /timesnetv2.xlsx
156 | /scripts/anomaly_detection/tmp/
157 | /scripts/imputation/tmp/
158 | /utils/self_tools.py
159 |
160 | checkpoints
161 | logs
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unified Time Series Model
2 |
3 | [**Project Page**](https://zitniklab.hms.harvard.edu/projects/UniTS/) | [**Paper link**](https://arxiv.org/pdf/2403.00131.pdf) **(NeurIPS 2024)**
4 |
5 | UniTS is a unified time series model that can process various tasks across multiple domains with shared parameters and does not have any task-specific modules.
6 |
7 | Authors: [Shanghua Gao](https://shgao.site/) [Teddy Koker](https://teddykoker.com) [Owen Queen](https://owencqueen.github.io/) [Thomas Hartvigsen](https://www.tomhartvigsen.com/) [Theodoros Tsiligkaridis](https://sites.google.com/view/theo-t) [Marinka Zitnik](https://zitniklab.hms.harvard.edu/)
8 |
9 | ## Overview
10 | Foundation models, especially LLMs, are profoundly transforming deep learning. Instead of training many task-specific models, we can adapt a single pretrained model to many tasks via few-shot prompting or fine-tuning. However, current foundation models apply to sequence data but not to time series, which present unique challenges due to the inherent diverse and multi-domain time series datasets, diverging task specifications across forecasting, classification and other types of tasks, and the apparent need for task-specialized models.
11 |
12 | We developed UniTS, a unified time series model that supports a universal task specification, accommodating classification, forecasting, imputation, and anomaly detection tasks. This is achieved through a novel unified network backbone, which incorporates sequence and variable attention along with a dynamic linear operator and is trained as a unified model.
13 |
14 | Across 38 multi-domain datasets, UniTS demonstrates superior performance compared to task-specific models and repurposed natural language-based LLMs. UniTS exhibits remarkable zero-shot, few-shot, and prompt learning capabilities when evaluated on new data domains and tasks.
15 |
16 |
17 |
18 |
19 |
20 | ## Setups
21 |
22 | ### 1. Requirements
23 | Install Pytorch2.0+ and the required packages.
24 | ```
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ### 2. Prepare data
29 | ```
30 | bash download_data_all.sh
31 | ```
32 | Dataset configs for the different multi-task settings are given in the `.yaml` files of the `data_provider` folder.
33 |
34 | By default, all experiments follow the multi-task setting where one UniTS model is jointly trained on multiple datasets.
35 |
36 | ### 3. Train and evaluate model
37 |
38 | #### 1. Multi-task learning on forecasting and classification tasks:
39 |
40 | - Pretraining + Prompt learning
41 | ```
42 | bash ./scripts/pretrain_prompt_learning/UniTS_pretrain_x128.sh
43 | ```
44 |
45 | - Supervised learning
46 | ```
47 | bash ./scripts/supervised_learning/UniTS_supervised.sh
48 | ```
49 |
50 | #### 2. Few-shot transfer learning on new forecasting and classification tasks:
51 |
52 | **Note: Please follow the instructions in the following training scripts to get the pretrained ckpt first.**
53 |
54 | - Finetuning
55 | ```
56 | # please set the pretrained model path in the script.
57 | bash ./scripts/few_shot_newdata/UniTS_finetune_few_shot_newdata_pct20.sh
58 | ```
59 |
60 | - Prompt tuning
61 | ```
62 | # please set the pretrained model path in the script.
63 | bash ./scripts/few_shot_newdata/UniTS_prompt_tuning_few_shot_newdata_pct20.sh
64 | ```
65 |
66 | #### 3. Few-shot transfer learning on anomaly detection tasks:
67 | - Finetuning
68 | ```
69 | # please set the pretrained model path in the script.
70 | bash ./scripts/few_shot_anomaly_detection/UniTS_finetune_few_shot_anomaly_detection.sh
71 | ```
72 | - Prompt tuning
73 | ```
74 | # please set the pretrained model path in the script.
75 | bash ./scripts/few_shot_anomaly_detection/UniTS_prompt_tuning_few_shot_anomaly_detection.sh
76 | ```
77 |
78 | #### 4. Few-shot transfer learning on imputation tasks:
79 | - Finetuning
80 | ```
81 | # please set the pretrained model path in the script.
82 | bash ./scripts/few_shot_imputation/UniTS_finetune_few_shot_imputation_mask050.sh
83 | ```
84 |
85 | - Prompt tuning
86 | ```
87 | # please set the pretrained model path in the script.
88 | bash ./scripts/few_shot_imputation/UniTS_prompt_tuning_few_shot_imputation_mask050.sh
89 | ```
90 |
91 | #### 5. Zero-shot learning on new forecasting length:
92 | ```
93 | # please set the pretrained model path in the script.
94 | bash ./scripts/zero_shot/UniTS_forecast_new_length_unify.sh
95 | ```
96 |
97 | #### 6. Zero-shot learning on new forecasting datasets:
98 | ```
99 | # A special version of UniTS with shared prompt/mask tokens needs to be trained for this setting.
100 | bash ./scripts/zero_shot/UniTS_zeroshot_newdata.sh
101 | ```
102 |
103 | ## Use UniTS on your own data.
104 | UniTS is a highly flexible unified time series model, supporting tasks such as forecasting, classification, imputation, and anomaly detection with a single shared model and shared weights. We provide a [Tutorial](Tutorial.md) to assist you in using your own data with UniTS.
105 |
106 | ## Pretrained weights
107 | We provide the pretrained weights for models mentioned above in [checkpoints](https://github.com/mims-harvard/UniTS/releases/tag/ckpt).
108 |
109 | ## Citation
110 |
111 | ```
112 | @article{gao2024building,
113 | title={UniTS: Building a Unified Time Series Model},
114 | author={Gao, Shanghua and Koker, Teddy and Queen, Owen and Hartvigsen, Thomas and Tsiligkaridis, Theodoros and Zitnik, Marinka},
115 | journal={arXiv},
116 | url={https://arxiv.org/pdf/2403.00131.pdf},
117 | year={2024}
118 | }
119 | ```
120 |
121 | ## Acknowledgement
122 | This codebase is built based on the [Time-Series-Library](https://github.com/thuml/Time-Series-Library). Thanks!
123 |
124 | ## Disclaimer
125 |
126 | DISTRIBUTION STATEMENT: Approved for public release. Distribution is unlimited.
127 |
128 | This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering.
129 |
130 | © 2024 Massachusetts Institute of Technology.
131 |
132 | Subject to FAR52.227-11 Patent Rights - Ownership by the contractor (May 2014)
133 |
134 | The software/firmware is provided to you on an As-Is basis
135 |
136 | Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other than as specifically authorized by the U.S. Government may violate any copyrights that exist in this work.
137 |
--------------------------------------------------------------------------------
/Tutorial.md:
--------------------------------------------------------------------------------
1 | # Quick start for using UniTS on your own data.
2 |
3 | ## Classification with your own data.
4 |
5 | We use a classification task as an example. The primary difference for other tasks lies in the data formats. You can follow the provided dataset as a guide to adapt your own data.
6 |
7 | ### 1. Prepare data
8 |
9 | We support common data formats of time series datasets.
10 |
11 | You can follow the [dataset format guide](https://www.aeon-toolkit.org/en/latest/examples/datasets/data_loading.html) to convert your dataset into the `.ts` format.
12 |
13 | The dataset should contain `newdata_TRAIN.ts` and `newdata_TEST.ts` files.
14 |
15 | ### 2. Define the dataset config file
16 |
17 | To support multiple datasets, our code base uses the `data_set.yaml` to keep the dataset information.
18 | Examples can be found in `data_provider` folder.
19 |
20 | Here is an example for classification dataset. You can add multiple dataset config in one config file if you want to make UniTS support multiple datasets.
21 | ```yaml
22 | task_dataset:
23 | CLS_ECG5000: # the dataset and task name
24 | task_name: classification # the type of task
25 | dataset: ECG5000 # the name of the dataset
26 | data: UEA # the data type of the dataset, use UEA if you use the '.ts' file
27 | embed: timeF # the embedding method used
28 | root_path: ../dataset/UCR/ECG5000 # the root path of the dataset
29 | seq_len: 140 # the length of the input sequence
30 | label_len: 0 # the length of the label sequence, 0 for classification
31 | pred_len: 0 # the length of the predicted sequence, 0 for classification
32 | enc_in: 1 # the number of input variables
33 | num_class: 5 # the number of classes
34 | c_out: None # the number of output variables, None for classification
35 | ```
36 |
37 | ### 3. Finetune your UniTS model
38 |
39 | #### Load Pretrained weights (Optional)
40 | You can load the pretrained SSL/Supervised UniTS model.
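41 | Run the [SSL Pretraining](scripts/pretrain_prompt_learning/UniTS_pretrain_x128.sh) or [Supervised training](scripts/supervised_learning/UniTS_supervised.sh) scripts to get the pretrained checkpoints.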
42 | Normally, SSL pretrained model has better transfer learning abilities.
43 |
44 | #### Setup finetuning script
45 |
46 | **Note: Remove the inline `#` comments before using the following scripts! A comment after a trailing `\` breaks the line continuation in bash.**
47 |
48 | - Finetuning/Supervised training
49 | ```bash
50 | model_name=UniTS # Model name, UniTS
51 | exp_name=UniTS_supervised_x64 # Exp name
52 | wandb_mode=online # Use wandb to log the training, change to disabled if you don't want to use it
53 | project_name=supervised_learning # project name in wandb
54 |
55 | random_port=$((RANDOM % 9000 + 1000))
56 |
57 | # Supervised learning
58 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \
59 | --is_training 1 \ # 1 for training, 0 for testing
60 | --model_id $exp_name \
61 | --model $model_name \
62 | --lradj supervised \ # You can define your own lr decay scheme in the adjust_learning_rate function of utils/tools.py
63 | --prompt_num 10 \ # The number of prompt tokens.
64 | --patch_len 16 \ # Patch size for each token in UniTS
65 | --stride 16 \ # Stride = patch size
66 | --e_layers 3 \
67 | --d_model 64 \
68 | --des 'Exp' \
69 | --learning_rate 1e-4 \ # Tune the following hp for your datasets. Due to the highly diverse nature of time series data, you might need to tune the hp for your new data.
70 | --weight_decay 5e-6 \
71 | --train_epochs 5 \
72 | --batch_size 32 \ # Real batch size = batch_size * acc_it
73 | --acc_it 32 \
74 | --debug $wandb_mode \
75 | --project_name $project_name \
76 | --clip_grad 100 \ # Grad clip to avoid Nan.
77 | --pretrained_weight ckpt_path.pth \ # Path of pretrained ckpt if you want to finetune the model, otherwise just remove it
78 | --task_data_config_path data_provider/multi_task.yaml # Important: Change to your_own_data_config.yaml
79 |
80 | ```
81 |
82 | - Prompt learning
83 |
84 | For prompt learning, only the tokens are finetuned and the model is fixed.
85 | **You must load pretrained model weights.**
86 | ```bash
87 | # Prompt tuning
88 | torchrun --nnodes 1 --master_port $random_port run.py \
89 | --is_training 1 \
90 | --model_id $exp_name \
91 | --model $model_name \
92 | --lradj prompt_tuning \
93 | --prompt_num 10 \
94 | --patch_len 16 \
95 | --stride 16 \
96 | --e_layers 3 \
97 | --d_model $d_model \
98 | --des 'Exp' \
99 | --itr 1 \
100 | --learning_rate 3e-3 \
101 | --weight_decay 0 \
102 | --prompt_tune_epoch 2 \ # Number of epochs for prompt tuning
103 | --train_epochs 0 \
104 | --acc_it 32 \
105 | --debug $wandb_mode \
106 | --project_name $ptune_name \
107 | --clip_grad 100 \
108 | --pretrained_weight auto \ # Path of pretrained ckpt, you must add it for prompt learning
109 | --task_data_config_path data_provider/multi_task.yaml # Important: Change to your_own_data_config.yaml
110 | ```
111 |
112 | ###
113 | Feel free to open an issue if you have any problems in using our code.
114 |
115 | This doc will be updated.
--------------------------------------------------------------------------------
/data_provider/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mims-harvard/UniTS/0e0281482864017cac8832b2651906ff5375a34e/data_provider/__init__.py
--------------------------------------------------------------------------------
/data_provider/anomaly_detection.yaml:
--------------------------------------------------------------------------------
1 | task_dataset:
2 | MSL:
3 | task_name: anomaly_detection
4 | dataset_name: MSL
5 | dataset: MSL
6 | data: MSL
7 | root_path: ../dataset/MSL
8 | seq_len: 96
9 | label_len: 0
10 | pred_len: 0
11 | features: M
12 | embed: timeF
13 | enc_in: 55
14 | dec_in: 55
15 | c_out: 55
16 |
17 | PSM:
18 | task_name: anomaly_detection
19 | dataset_name: PSM
20 | dataset: PSM
21 | data: PSM
22 | root_path: ../dataset/PSM
23 | seq_len: 96
24 | label_len: 0
25 | pred_len: 0
26 | features: M
27 | embed: timeF
28 | enc_in: 25
29 | dec_in: 25
30 | c_out: 25
31 |
32 | SMAP:
33 | task_name: anomaly_detection
34 | dataset_name: SMAP
35 | dataset: SMAP
36 | data: SMAP
37 | root_path: ../dataset/SMAP
38 | seq_len: 96
39 | label_len: 0
40 | pred_len: 0
41 | features: M
42 | embed: timeF
43 | enc_in: 25
44 | dec_in: 25
45 | c_out: 25
46 |
47 | SMD:
48 | task_name: anomaly_detection
49 | dataset_name: SMD
50 | dataset: SMD
51 | data: SMD
52 | root_path: ../dataset/SMD
53 | seq_len: 96
54 | label_len: 0
55 | pred_len: 0
56 | features: M
57 | embed: timeF
58 | enc_in: 38
59 | dec_in: 38
60 | c_out: 38
61 |
62 | SWAT:
63 | task_name: anomaly_detection
64 | dataset_name: SWAT
65 | dataset: SWAT
66 | data: SWAT
67 | root_path: ../dataset/SWaT
68 | seq_len: 96
69 | label_len: 0
70 | pred_len: 0
71 | features: M
72 | embed: timeF
73 | enc_in: 51
74 | dec_in: 51
75 | c_out: 51
--------------------------------------------------------------------------------
/data_provider/data_factory.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, PSMSegLoader, \
2 | MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader, GLUONTSDataset
3 | from data_provider.uea import collate_fn
4 | import torch
5 | from torch.utils.data import DataLoader, Subset
6 | from torch.utils.data.distributed import DistributedSampler
7 |
# Registry mapping the `data` field of a task config to the dataset class
# that loads it; `data_provider` looks classes up here by that key.
data_dict = {
    'ETTh1': Dataset_ETT_hour,
    'ETTh2': Dataset_ETT_hour,
    'ETTm1': Dataset_ETT_minute,
    'ETTm2': Dataset_ETT_minute,
    'custom': Dataset_Custom,
    # 'm4': Dataset_M4, Removed due to the LICENSE file constraints of m4.py
    'PSM': PSMSegLoader,
    'MSL': MSLSegLoader,
    'SMAP': SMAPSegLoader,
    'SMD': SMDSegLoader,
    'SWAT': SWATSegLoader,
    'UEA': UEAloader,
    # datasets from gluonts package:
    "gluonts": GLUONTSDataset,
}
24 |
25 |
def random_subset(dataset, pct, seed):
    """Return a reproducible random ``Subset`` holding a fraction of ``dataset``.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        pct: Fraction of samples to keep; the kept count is
            ``int(len(dataset) * pct)`` (truncated towards zero).
        seed: Seed for the private ``torch.Generator`` that drives the
            permutation, so the same seed always selects the same samples.

    Returns:
        A ``torch.utils.data.Subset`` over the selected indices.
    """
    total = len(dataset)

    # A dedicated generator keeps the selection independent of global RNG state.
    rng = torch.Generator()
    rng.manual_seed(seed)
    shuffled = torch.randperm(total, generator=rng)

    keep = int(total * pct)
    chosen = shuffled[:keep].long().numpy()
    return Subset(dataset, chosen)
31 |
32 |
def _maybe_subsample(data_set, args, flag):
    """Return a seeded random subset of the training split when requested."""
    if args.subsample_pct is not None and flag == "train":
        return random_subset(data_set, args.subsample_pct, args.fix_seed)
    return data_set


def _make_loader(data_set, args, batch_size, shuffle_flag, drop_last, ddp,
                 **loader_kwargs):
    """Build a DataLoader; under DDP a DistributedSampler replaces shuffling."""
    return DataLoader(
        data_set,
        batch_size=batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler.
        shuffle=False if ddp else shuffle_flag,
        num_workers=args.num_workers,
        sampler=DistributedSampler(data_set) if ddp else None,
        drop_last=drop_last,
        **loader_kwargs
    )


def data_provider(args, config, flag, ddp=False):  # args,
    """Create the dataset and DataLoader for one task config entry.

    Args:
        args: Run arguments; this function reads ``batch_size``, ``freq``,
            ``num_workers``, ``subsample_pct``, ``fix_seed`` and ``target``.
        config: One task entry from a task/dataset YAML file. ``data``,
            ``embed`` and ``task_name`` are always read; the remaining keys
            depend on the task branch taken below.
        flag: Split name. ``'test'`` disables shuffling and ``drop_last``;
            ``'train'`` additionally enables optional subsampling.
        ddp: When True, a ``DistributedSampler`` is attached to the loader
            (not applied to gluonts datasets, see note below).

    Returns:
        Tuple ``(data_set, data_loader)``.

    Raises:
        KeyError: If ``config['data']`` is not registered in ``data_dict``.
    """
    data_key = config['data']
    if data_key not in data_dict:
        # Still a KeyError (as the bare dict lookup raised), but actionable.
        raise KeyError(
            f"Unknown data type {data_key!r}; registered types: "
            f"{sorted(data_dict)}. Note that the 'm4' entry is commented out "
            f"of data_dict and must be restored together with m4.py to be used."
        )
    Data = data_dict[data_key]
    timeenc = 0 if config['embed'] != 'timeF' else 1
    freq = args.freq

    if flag == 'test':
        shuffle_flag = False
        drop_last = False
        if 'anomaly_detection' in config['task_name']:  # working on one gpu
            batch_size = args.batch_size
        else:
            batch_size = 1  # bsz=1 for evaluation
    else:
        shuffle_flag = True
        drop_last = True
        batch_size = args.batch_size  # bsz for train and valid

    if 'gluonts' in data_key:
        # process gluonts dataset:
        data_set = Data(
            dataset_name=config['dataset_name'],
            size=(config['seq_len'], config['label_len'], config['pred_len']),
            path=config['root_path'],
            # Don't set dataset_writer
            features=config["features"],
            flag=flag,
        )
        data_set = _maybe_subsample(data_set, args, flag)
        # NOTE(review): this branch never attaches a DistributedSampler, even
        # when ddp=True; kept as-is -- confirm that this is intentional.
        data_loader = _make_loader(
            data_set, args, batch_size, shuffle_flag, drop_last, ddp=False)
        return data_set, data_loader

    if 'anomaly_detection' in config['task_name']:
        drop_last = False
        data_set = Data(
            root_path=config['root_path'],
            win_size=config['seq_len'],
            flag=flag,
        )
        data_set = _maybe_subsample(data_set, args, flag)
        print("ddp mode is set to false for anomaly_detection", ddp, len(data_set))
        data_loader = _make_loader(
            data_set, args, batch_size, shuffle_flag, drop_last, ddp)
        return data_set, data_loader

    if 'classification' in config['task_name']:
        drop_last = False
        data_set = Data(
            root_path=config['root_path'],
            flag=flag,
        )
        data_set = _maybe_subsample(data_set, args, flag)
        print(flag, len(data_set))
        data_loader = _make_loader(
            data_set, args, batch_size, shuffle_flag, drop_last, ddp,
            # Pads/truncates each batch to the configured sequence length.
            collate_fn=lambda x: collate_fn(x, max_len=config['seq_len'])
        )
        return data_set, data_loader

    # Remaining tasks (forecasting, imputation, ...) share one constructor.
    if data_key == 'm4':
        drop_last = False
    data_set = Data(
        root_path=config['root_path'],
        data_path=config['data_path'],
        flag=flag,
        size=[config['seq_len'], config['label_len'], config['pred_len']],
        features=config['features'],
        target=args.target,
        timeenc=timeenc,
        freq=freq,
        seasonal_patterns=config['seasonal_patterns'] if data_key == 'm4' else None
    )
    data_set = _maybe_subsample(data_set, args, flag)
    print(flag, len(data_set))
    data_loader = _make_loader(
        data_set, args, batch_size, shuffle_flag, drop_last, ddp)
    return data_set, data_loader
142 |
--------------------------------------------------------------------------------
/data_provider/fewshot_new_task.yaml:
--------------------------------------------------------------------------------
1 | task_dataset:
2 |
3 | CLS_ECG200:
4 | task_name: classification
5 | dataset: ECG200
6 | data: UEA
7 | embed: timeF
8 | root_path: ../dataset/ECG200
9 | seq_len: 96
10 | label_len: 0
11 | pred_len: 0
12 | enc_in: 1
13 | num_class: 2
14 | c_out: None
15 |
16 | CLS_Handwriting:
17 | task_name: classification
18 | dataset: Handwriting
19 | data: UEA
20 | embed: timeF
21 | root_path: ../dataset/Handwriting
22 | seq_len: 152
23 | label_len: 0
24 | pred_len: 0
25 | enc_in: 3
26 | num_class: 26
27 | c_out: None
28 |
29 | CLS_SelfRegulationSCP1:
30 | task_name: classification
31 | dataset: SelfRegulationSCP1
32 | data: UEA
33 | embed: timeF
34 | root_path: ../dataset/SelfRegulationSCP1
35 | seq_len: 896
36 | label_len: 0
37 | pred_len: 0
38 | enc_in: 6
39 | num_class: 2
40 | c_out: None
41 |
42 |
43 | CLS_RacketSports:
44 | task_name: classification
45 | dataset: RacketSports
46 | data: UEA
47 | embed: timeF
48 | root_path: ../dataset/UAE/RacketSports
49 | seq_len: 30
50 | label_len: 0
51 | pred_len: 0
52 | enc_in: 6
53 | num_class: 4
54 | c_out: None
55 |
56 | CLS_Epilepsy:
57 | task_name: classification
58 | dataset: Epilepsy
59 | data: UEA
60 | embed: timeF
61 | root_path: ../dataset/UAE/Epilepsy
62 | seq_len: 207
63 | label_len: 0
64 | pred_len: 0
65 | enc_in: 3
66 | num_class: 4
67 | c_out: None
68 |
69 | CLS_StarLightCurves:
70 | task_name: classification
71 | dataset: StarLightCurves
72 | data: UEA
73 | embed: timeF
74 | root_path: ../dataset/UCR/StarLightCurves
75 | seq_len: 1024
76 | label_len: 0
77 | pred_len: 0
78 | enc_in: 1
79 | num_class: 3
80 | c_out: None
81 |
82 | LTF_ETTh2_p96:
83 | task_name: long_term_forecast
84 | dataset: ETTh2
85 | data: ETTh2
86 | embed: timeF
87 | root_path: ../dataset/ETT-small/
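88 | data_path: ETTh1.csv # NOTE(review): this task's dataset is ETTh2, and the other ETTh2 entries use ETTh2.csv; confirm whether ETTh1.csv is intended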
89 | features: M
90 | seq_len: 96
91 | label_len: 48
92 | pred_len: 96
93 | enc_in: 7
94 | dec_in: 7
95 | c_out: 7
96 |
97 | LTF_ETTh2_p192:
98 | task_name: long_term_forecast
99 | dataset: ETTh2
100 | data: ETTh2
101 | embed: timeF
102 | root_path: ../dataset/ETT-small/
103 | data_path: ETTh2.csv
104 | features: M
105 | seq_len: 96
106 | label_len: 48
107 | pred_len: 192
108 | enc_in: 7
109 | dec_in: 7
110 | c_out: 7
111 |
112 | LTF_ETTh2_p336:
113 | task_name: long_term_forecast
114 | dataset: ETTh2
115 | data: ETTh2
116 | embed: timeF
117 | root_path: ../dataset/ETT-small/
118 | data_path: ETTh2.csv
119 | features: M
120 | seq_len: 96
121 | label_len: 48
122 | pred_len: 336
123 | enc_in: 7
124 | dec_in: 7
125 | c_out: 7
126 |
127 | LTF_ETTh2_p720:
128 | task_name: long_term_forecast
129 | dataset: ETTh2
130 | data: ETTh2
131 | embed: timeF
132 | root_path: ../dataset/ETT-small/
133 | data_path: ETTh2.csv
134 | features: M
135 | seq_len: 96
136 | label_len: 48
137 | pred_len: 720
138 | enc_in: 7
139 | dec_in: 7
140 | c_out: 7
141 |
142 | LTF_SaugeenRiverFlow:
143 | task_name: long_term_forecast
144 | dataset_name: saugeenday
145 | dataset: SaugeenRiverFlow
146 | data: gluonts
147 | root_path: ../dataset/gluonts
148 | seq_len: 48
149 | label_len: 0
150 | pred_len: 24
151 | features: M
152 | embed: timeF
153 | enc_in: 1
154 | dec_in: 1
155 | c_out: 1
156 |
157 | LTF_ETTm1_p96:
158 | task_name: long_term_forecast
159 | dataset: ETTm1
160 | data: ETTm1
161 | embed: timeF
162 | root_path: ../dataset/ETT-small/
163 | data_path: ETTm1.csv
164 | features: M
165 | seq_len: 96
166 | label_len: 48
167 | pred_len: 96
168 | enc_in: 7
169 | dec_in: 7
170 | c_out: 7
171 |
172 | LTF_ETTm1_p192:
173 | task_name: long_term_forecast
174 | dataset: ETTm1
175 | data: ETTm1
176 | embed: timeF
177 | root_path: ../dataset/ETT-small/
178 | data_path: ETTm1.csv
179 | features: M
180 | seq_len: 96
181 | label_len: 48
182 | pred_len: 192
183 | enc_in: 7
184 | dec_in: 7
185 | c_out: 7
186 |
187 | LTF_ETTm1_p336:
188 | task_name: long_term_forecast
189 | dataset: ETTm1
190 | data: ETTm1
191 | embed: timeF
192 | root_path: ../dataset/ETT-small/
193 | data_path: ETTm1.csv
194 | features: M
195 | seq_len: 96
196 | label_len: 48
197 | pred_len: 336
198 | enc_in: 7
199 | dec_in: 7
200 | c_out: 7
201 |
202 | LTF_ETTm1_p720:
203 | task_name: long_term_forecast
204 | dataset: ETTm1
205 | data: ETTm1
206 | embed: timeF
207 | root_path: ../dataset/ETT-small/
208 | data_path: ETTm1.csv
209 | features: M
210 | seq_len: 96
211 | label_len: 48
212 | pred_len: 720
213 | enc_in: 7
214 | dec_in: 7
215 | c_out: 7
216 |
217 |
--------------------------------------------------------------------------------
/data_provider/imputation.yaml:
--------------------------------------------------------------------------------
1 | task_dataset:
2 | LTF_ECL_p96:
3 | task_name: imputation
4 | dataset: ECL
5 | data: custom
6 | embed: timeF
7 | root_path: ../dataset/electricity/
8 | data_path: electricity.csv
9 | features: M
10 | seq_len: 96
11 | label_len: 0
12 | pred_len: 0
13 | enc_in: 321
14 | dec_in: 321
15 | c_out: 321
16 |
17 | LTF_ETTh1_p96:
18 | task_name: imputation
19 | dataset: ETTh1
20 | data: ETTh1
21 | embed: timeF
22 | root_path: ../dataset/ETT-small/
23 | data_path: ETTh1.csv
24 | features: M
25 | seq_len: 96
26 | label_len: 0
27 | pred_len: 0
28 | enc_in: 7
29 | dec_in: 7
30 | c_out: 7
31 |
32 | LTF_Weather_p96:
33 | task_name: imputation
34 | dataset: Weather
35 | data: custom
36 | embed: timeF
37 | root_path: ../dataset/weather/
38 | data_path: weather.csv
39 | features: M
40 | seq_len: 96
41 | label_len: 0
42 | pred_len: 0
43 | enc_in: 21
44 | dec_in: 21
45 | c_out: 21
46 |
47 | LTF_ETTh2_p96:
48 | task_name: imputation
49 | dataset: ETTh2
50 | data: ETTh2
51 | embed: timeF
52 | root_path: ../dataset/ETT-small/
53 | data_path: ETTh2.csv
54 | features: M
55 | seq_len: 96
56 | label_len: 0
57 | pred_len: 0
58 | enc_in: 7
59 | dec_in: 7
60 | c_out: 7
61 |
62 | LTF_ETTm1_p96:
63 | task_name: imputation
64 | dataset: ETTm1
65 | data: ETTm1
66 | embed: timeF
67 | root_path: ../dataset/ETT-small/
68 | data_path: ETTm1.csv
69 | features: M
70 | seq_len: 96
71 | label_len: 0
72 | pred_len: 0
73 | enc_in: 7
74 | dec_in: 7
75 | c_out: 7
76 |
77 | LTF_ETTm2_p96:
78 | task_name: imputation
79 | dataset: ETTm2
80 | data: ETTm2
81 | embed: timeF
82 | root_path: ../dataset/ETT-small/
83 | data_path: ETTm2.csv
84 | features: M
85 | seq_len: 96
86 | label_len: 0
87 | pred_len: 0
88 | enc_in: 7
89 | dec_in: 7
90 | c_out: 7
91 |
--------------------------------------------------------------------------------
/data_provider/m4.py:
--------------------------------------------------------------------------------
1 | # This file is removed due to LICENSE file constraints.
2 | # You can copy the m4.py from https://github.com/thuml/Time-Series-Library/blob/main/data_provider/m4.py
--------------------------------------------------------------------------------
/data_provider/multi_task.yaml:
--------------------------------------------------------------------------------
1 | task_dataset:
2 | NN5_p112:
3 | task_name: long_term_forecast
4 | dataset_name: nn5_daily_without_missing
5 | dataset: NN5
6 | data: gluonts
7 | root_path: ../dataset/gluonts
8 | seq_len: 112
9 | label_len: 0
10 | pred_len: 112
11 | features: M
12 | embed: timeF
13 | enc_in: 111
14 | dec_in: 111
15 | c_out: 111
16 |
17 | LTF_ECL_p96:
18 | task_name: long_term_forecast
19 | dataset: ECL
20 | data: custom
21 | embed: timeF
22 | root_path: ../dataset/electricity/
23 | data_path: electricity.csv
24 | features: M
25 | seq_len: 96
26 | label_len: 48
27 | pred_len: 96
28 | enc_in: 321
29 | dec_in: 321
30 | c_out: 321
31 |
32 | LTF_ECL_p192:
33 | task_name: long_term_forecast
34 | dataset: ECL
35 | data: custom
36 | embed: timeF
37 | root_path: ../dataset/electricity/
38 | data_path: electricity.csv
39 | features: M
40 | seq_len: 96
41 | label_len: 48
42 | pred_len: 192
43 | enc_in: 321
44 | dec_in: 321
45 | c_out: 321
46 |
47 | LTF_ECL_p336:
48 | task_name: long_term_forecast
49 | dataset: ECL
50 | data: custom
51 | embed: timeF
52 | root_path: ../dataset/electricity/
53 | data_path: electricity.csv
54 | features: M
55 | seq_len: 96
56 | label_len: 48
57 | pred_len: 336
58 | enc_in: 321
59 | dec_in: 321
60 | c_out: 321
61 |
62 | LTF_ECL_p720:
63 | task_name: long_term_forecast
64 | dataset: ECL
65 | data: custom
66 | embed: timeF
67 | root_path: ../dataset/electricity/
68 | data_path: electricity.csv
69 | features: M
70 | seq_len: 96
71 | label_len: 48
72 | pred_len: 720
73 | enc_in: 321
74 | dec_in: 321
75 | c_out: 321
76 |
77 | LTF_ETTh1_p96:
78 | task_name: long_term_forecast
79 | dataset: ETTh1
80 | data: ETTh1
81 | embed: timeF
82 | root_path: ../dataset/ETT-small/
83 | data_path: ETTh1.csv
84 | features: M
85 | seq_len: 96
86 | label_len: 48
87 | pred_len: 96
88 | enc_in: 7
89 | dec_in: 7
90 | c_out: 7
91 |
92 | LTF_ETTh1_p192:
93 | task_name: long_term_forecast
94 | dataset: ETTh1
95 | data: ETTh1
96 | embed: timeF
97 | root_path: ../dataset/ETT-small/
98 | data_path: ETTh1.csv
99 | features: M
100 | seq_len: 96
101 | label_len: 48
102 | pred_len: 192
103 | enc_in: 7
104 | dec_in: 7
105 | c_out: 7
106 |
107 | LTF_ETTh1_p336:
108 | task_name: long_term_forecast
109 | dataset: ETTh1
110 | data: ETTh1
111 | embed: timeF
112 | root_path: ../dataset/ETT-small/
113 | data_path: ETTh1.csv
114 | features: M
115 | seq_len: 96
116 | label_len: 48
117 | pred_len: 336
118 | enc_in: 7
119 | dec_in: 7
120 | c_out: 7
121 |
122 | LTF_ETTh1_p720:
123 | task_name: long_term_forecast
124 | dataset: ETTh1
125 | data: ETTh1
126 | embed: timeF
127 | root_path: ../dataset/ETT-small/
128 | data_path: ETTh1.csv
129 | features: M
130 | seq_len: 96
131 | label_len: 48
132 | pred_len: 720
133 | enc_in: 7
134 | dec_in: 7
135 | c_out: 7
136 |
137 | LTF_Exchange_p192:
138 | task_name: long_term_forecast
139 | dataset: Exchange
140 | data: custom
141 | embed: timeF
142 | root_path: ../dataset/exchange_rate/
143 | data_path: exchange_rate.csv
144 | features: M
145 | seq_len: 96
146 | label_len: 48
147 | pred_len: 192
148 | enc_in: 8
149 | dec_in: 8
150 | c_out: 8
151 |
152 | LTF_Exchange_p336:
153 | task_name: long_term_forecast
154 | dataset: Exchange
155 | data: custom
156 | embed: timeF
157 | root_path: ../dataset/exchange_rate/
158 | data_path: exchange_rate.csv
159 | features: M
160 | seq_len: 96
161 | label_len: 48
162 | pred_len: 336
163 | enc_in: 8
164 | dec_in: 8
165 | c_out: 8
166 |
167 | LTF_ILI_p60:
168 | task_name: long_term_forecast
169 | dataset: ILI
170 | data: custom
171 | embed: timeF
172 | root_path: ../dataset/illness/
173 | data_path: national_illness.csv
174 | features: M
175 | seq_len: 36
176 | label_len: 18
177 | pred_len: 60
178 | enc_in: 7
179 | dec_in: 7
180 | c_out: 7
181 |
182 | LTF_Traffic_p96:
183 | task_name: long_term_forecast
184 | dataset: Traffic
185 | data: custom
186 | embed: timeF
187 | root_path: ../dataset/traffic/
188 | data_path: traffic.csv
189 | features: M
190 | seq_len: 96
191 | label_len: 48
192 | pred_len: 96
193 | enc_in: 862
194 | dec_in: 862
195 | c_out: 862
196 |
197 | LTF_Traffic_p192:
198 | task_name: long_term_forecast
199 | dataset: Traffic
200 | data: custom
201 | embed: timeF
202 | root_path: ../dataset/traffic/
203 | data_path: traffic.csv
204 | features: M
205 | seq_len: 96
206 | label_len: 48
207 | pred_len: 192
208 | enc_in: 862
209 | dec_in: 862
210 | c_out: 862
211 |
212 | LTF_Traffic_p336:
213 | task_name: long_term_forecast
214 | dataset: Traffic
215 | data: custom
216 | embed: timeF
217 | root_path: ../dataset/traffic/
218 | data_path: traffic.csv
219 | features: M
220 | seq_len: 96
221 | label_len: 48
222 | pred_len: 336
223 | enc_in: 862
224 | dec_in: 862
225 | c_out: 862
226 |
227 | LTF_Traffic_p720:
228 | task_name: long_term_forecast
229 | dataset: Traffic
230 | data: custom
231 | embed: timeF
232 | root_path: ../dataset/traffic/
233 | data_path: traffic.csv
234 | features: M
235 | seq_len: 96
236 | label_len: 48
237 | pred_len: 720
238 | enc_in: 862
239 | dec_in: 862
240 | c_out: 862
241 |
242 | LTF_Weather_p96:
243 | task_name: long_term_forecast
244 | dataset: Weather
245 | data: custom
246 | embed: timeF
247 | root_path: ../dataset/weather/
248 | data_path: weather.csv
249 | features: M
250 | seq_len: 96
251 | label_len: 48
252 | pred_len: 96
253 | enc_in: 21
254 | dec_in: 21
255 | c_out: 21
256 |
257 | LTF_Weather_p192:
258 | task_name: long_term_forecast
259 | dataset: Weather
260 | data: custom
261 | embed: timeF
262 | root_path: ../dataset/weather/
263 | data_path: weather.csv
264 | features: M
265 | seq_len: 96
266 | label_len: 48
267 | pred_len: 192
268 | enc_in: 21
269 | dec_in: 21
270 | c_out: 21
271 |
272 | LTF_Weather_p336:
273 | task_name: long_term_forecast
274 | dataset: Weather
275 | data: custom
276 | embed: timeF
277 | root_path: ../dataset/weather/
278 | data_path: weather.csv
279 | features: M
280 | seq_len: 96
281 | label_len: 48
282 | pred_len: 336
283 | enc_in: 21
284 | dec_in: 21
285 | c_out: 21
286 |
287 | LTF_Weather_p720:
288 | task_name: long_term_forecast
289 | dataset: Weather
290 | data: custom
291 | embed: timeF
292 | root_path: ../dataset/weather/
293 | data_path: weather.csv
294 | features: M
295 | seq_len: 96
296 | label_len: 48
297 | pred_len: 720
298 | enc_in: 21
299 | dec_in: 21
300 | c_out: 21
301 |
302 | CLS_Heartbeat:
303 | task_name: classification
304 | dataset: Heartbeat
305 | data: UEA
306 | embed: timeF
307 | root_path: ../dataset/Heartbeat/
308 | seq_len: 405
309 | label_len: 0
310 | pred_len: 0
311 | enc_in: 61
312 | num_class: 2
313 | c_out: None
314 |
315 | CLS_JapaneseVowels:
316 | task_name: classification
317 | dataset: JapaneseVowels
318 | data: UEA
319 | embed: timeF
320 | root_path: ../dataset/JapaneseVowels/
321 | seq_len: 29
322 | label_len: 0
323 | pred_len: 0
324 | enc_in: 12
325 | num_class: 9
326 | c_out: None
327 |
328 | CLS_PEMS-SF:
329 | task_name: classification
330 | dataset: PEMS-SF
331 | data: UEA
332 | embed: timeF
333 | root_path: ../dataset/PEMS-SF/
334 | seq_len: 144
335 | label_len: 0
336 | pred_len: 0
337 | enc_in: 963
338 | num_class: 7
339 | c_out: None
340 |
341 | CLS_SelfRegulationSCP2:
342 | task_name: classification
343 | dataset: SelfRegulationSCP2
344 | data: UEA
345 | embed: timeF
346 | root_path: ../dataset/SelfRegulationSCP2/
347 | seq_len: 1152
348 | label_len: 0
349 | pred_len: 0
350 | enc_in: 7
351 | num_class: 2
352 | c_out: None
353 |
354 | CLS_SpokenArabicDigits:
355 | task_name: classification
356 | dataset: SpokenArabicDigits
357 | data: UEA
358 | embed: timeF
359 | root_path: ../dataset/SpokenArabicDigits/
360 | seq_len: 93
361 | label_len: 0
362 | pred_len: 0
363 | enc_in: 13
364 | num_class: 10
365 | c_out: None
366 |
367 | CLS_UWaveGestureLibrary:
368 | task_name: classification
369 | dataset: UWaveGestureLibrary
370 | data: UEA
371 | embed: timeF
372 | root_path: ../dataset/UWaveGestureLibrary/
373 | seq_len: 315
374 | label_len: 0
375 | pred_len: 0
376 | enc_in: 3
377 | num_class: 8
378 | c_out: None
379 |
380 | CLS_ECG5000:
381 | task_name: classification
382 | dataset: ECG5000
383 | data: UEA
384 | embed: timeF
385 | root_path: ../dataset/UCR/ECG5000
386 | seq_len: 140
387 | label_len: 0
388 | pred_len: 0
389 | enc_in: 1
390 | num_class: 5
391 | c_out: None
392 |
393 | CLS_NonInvasiveFetalECGThorax1:
394 | task_name: classification
395 | dataset: NonInvasiveFetalECGThorax1
396 | data: UEA
397 | embed: timeF
398 | root_path: ../dataset/UCR/NonInvasiveFetalECGThorax1
399 | seq_len: 750
400 | label_len: 0
401 | pred_len: 0
402 | enc_in: 1
403 | num_class: 52
404 | c_out: None
405 |
406 | CLS_Blink:
407 | task_name: classification
408 | dataset: Blink
409 | data: UEA
410 | embed: timeF
411 | root_path: ../dataset/Blink
412 | seq_len: 510
413 | label_len: 0
414 | pred_len: 0
415 | enc_in: 4
416 | num_class: 2
417 | c_out: None
418 |
419 | CLS_FaceDetection:
420 | task_name: classification
421 | dataset: FaceDetection
422 | data: UEA
423 | embed: timeF
424 | root_path: ../dataset/FaceDetection
425 | seq_len: 62
426 | label_len: 0
427 | pred_len: 0
428 | enc_in: 144
429 | num_class: 2
430 | c_out: None
431 |
432 | CLS_ElectricDevices:
433 | task_name: classification
434 | dataset: ElectricDevices
435 | data: UEA
436 | embed: timeF
437 | root_path: ../dataset/UCR/ElectricDevices
438 | seq_len: 96
439 | label_len: 0
440 | pred_len: 0
441 | enc_in: 1
442 | num_class: 7
443 | c_out: None
444 |
445 | CLS_Trace:
446 | task_name: classification
447 | dataset: Trace
448 | data: UEA
449 | embed: timeF
450 | root_path: ../dataset/UCR/Trace
451 | seq_len: 275
452 | label_len: 0
453 | pred_len: 0
454 | enc_in: 1
455 | num_class: 4
456 | c_out: None
457 |
458 | CLS_FordB:
459 | task_name: classification
460 | dataset: FordB
461 | data: UEA
462 | embed: timeF
463 | root_path: ../dataset/UCR/FordB
464 | seq_len: 500
465 | label_len: 0
466 | pred_len: 0
467 | enc_in: 1
468 | num_class: 2
469 | c_out: None
470 |
471 | CLS_MotionSenseHAR:
472 | task_name: classification
473 | dataset: MotionSenseHAR
474 | data: UEA
475 | embed: timeF
476 | root_path: ../dataset/MotionSenseHAR
477 | seq_len: 200
478 | label_len: 0
479 | pred_len: 0
480 | enc_in: 12
481 | num_class: 6
482 | c_out: None
483 |
484 | CLS_EMOPain:
485 | task_name: classification
486 | dataset: EMOPain
487 | data: UEA
488 | embed: timeF
489 | root_path: ../dataset/EMOPain
490 | seq_len: 180
491 | label_len: 0
492 | pred_len: 0
493 | enc_in: 30
494 | num_class: 3
495 | c_out: None
496 |
497 | CLS_Chinatown:
498 | task_name: classification
499 | dataset: Chinatown
500 | data: UEA
501 | embed: timeF
502 | root_path: ../dataset/UCR/Chinatown
503 | seq_len: 24
504 | label_len: 0
505 | pred_len: 0
506 | enc_in: 1
507 | num_class: 2
508 | c_out: None
509 |
510 | CLS_MelbournePedestrian:
511 | task_name: classification
512 | dataset: MelbournePedestrian
513 | data: UEA
514 | embed: timeF
515 | root_path: ../dataset/UCR/MelbournePedestrian
516 | seq_len: 24
517 | label_len: 0
518 | pred_len: 0
519 | enc_in: 1
520 | num_class: 10
521 | c_out: None
522 |
523 | CLS_SharePriceIncrease:
524 | task_name: classification
525 | dataset: SharePriceIncrease
526 | data: UEA
527 | embed: timeF
528 | root_path: ../dataset/SharePriceIncrease
529 | seq_len: 60
530 | label_len: 0
531 | pred_len: 0
532 | enc_in: 1
533 | num_class: 2
534 | c_out: None
--------------------------------------------------------------------------------
/data_provider/multi_task_pretrain.yaml:
--------------------------------------------------------------------------------
1 | task_dataset:
2 | NN5_p112:
3 | task_name: pretrain_long_term_forecast
4 | dataset_name: nn5_daily_without_missing
5 | dataset: NN5
6 | data: gluonts
7 | root_path: ../dataset/gluonts
8 | seq_len: 224
9 | label_len: 0
10 | pred_len: 0
11 | features: M
12 | embed: timeF
13 | enc_in: 111
14 | dec_in: 111
15 | c_out: 111
16 |
17 | LTF_ECL_p96:
18 | task_name: pretrain_long_term_forecast
19 | dataset: ECL
20 | data: custom
21 | embed: timeF
22 | root_path: ../dataset/electricity/
23 | data_path: electricity.csv
24 | features: M
25 | seq_len: 192
26 | label_len: 48
27 | pred_len: 0
28 | enc_in: 321
29 | dec_in: 321
30 | c_out: 321
31 |
32 | LTF_ECL_p192:
33 | task_name: pretrain_long_term_forecast
34 | dataset: ECL
35 | data: custom
36 | embed: timeF
37 | root_path: ../dataset/electricity/
38 | data_path: electricity.csv
39 | features: M
40 | seq_len: 288
41 | label_len: 48
42 | pred_len: 0
43 | enc_in: 321
44 | dec_in: 321
45 | c_out: 321
46 |
47 | LTF_ECL_p336:
48 | task_name: pretrain_long_term_forecast
49 | dataset: ECL
50 | data: custom
51 | embed: timeF
52 | root_path: ../dataset/electricity/
53 | data_path: electricity.csv
54 | features: M
55 | seq_len: 432
56 | label_len: 48
57 | pred_len: 0
58 | enc_in: 321
59 | dec_in: 321
60 | c_out: 321
61 |
62 | LTF_ECL_p720:
63 | task_name: pretrain_long_term_forecast
64 | dataset: ECL
65 | data: custom
66 | embed: timeF
67 | root_path: ../dataset/electricity/
68 | data_path: electricity.csv
69 | features: M
70 | seq_len: 816
71 | label_len: 48
72 | pred_len: 0
73 | enc_in: 321
74 | dec_in: 321
75 | c_out: 321
76 |
77 | LTF_ETTh1_p96:
78 | task_name: pretrain_long_term_forecast
79 | dataset: ETTh1
80 | data: ETTh1
81 | embed: timeF
82 | root_path: ../dataset/ETT-small/
83 | data_path: ETTh1.csv
84 | features: M
85 | seq_len: 192
86 | label_len: 48
87 | pred_len: 0
88 | enc_in: 7
89 | dec_in: 7
90 | c_out: 7
91 |
92 | LTF_ETTh1_p192:
93 | task_name: pretrain_long_term_forecast
94 | dataset: ETTh1
95 | data: ETTh1
96 | embed: timeF
97 | root_path: ../dataset/ETT-small/
98 | data_path: ETTh1.csv
99 | features: M
100 | seq_len: 288
101 | label_len: 48
102 | pred_len: 0
103 | enc_in: 7
104 | dec_in: 7
105 | c_out: 7
106 |
107 | LTF_ETTh1_p336:
108 | task_name: pretrain_long_term_forecast
109 | dataset: ETTh1
110 | data: ETTh1
111 | embed: timeF
112 | root_path: ../dataset/ETT-small/
113 | data_path: ETTh1.csv
114 | features: M
115 | seq_len: 432
116 | label_len: 48
117 | pred_len: 0
118 | enc_in: 7
119 | dec_in: 7
120 | c_out: 7
121 |
122 | LTF_ETTh1_p720:
123 | task_name: pretrain_long_term_forecast
124 | dataset: ETTh1
125 | data: ETTh1
126 | embed: timeF
127 | root_path: ../dataset/ETT-small/
128 | data_path: ETTh1.csv
129 | features: M
130 | seq_len: 816
131 | label_len: 48
132 | pred_len: 0
133 | enc_in: 7
134 | dec_in: 7
135 | c_out: 7
136 |
137 | LTF_Exchange_p192:
138 | task_name: pretrain_long_term_forecast
139 | dataset: Exchange
140 | data: custom
141 | embed: timeF
142 | root_path: ../dataset/exchange_rate/
143 | data_path: exchange_rate.csv
144 | features: M
145 | seq_len: 288
146 | label_len: 48
147 | pred_len: 0
148 | enc_in: 8
149 | dec_in: 8
150 | c_out: 8
151 |
152 | LTF_Exchange_p336:
153 | task_name: pretrain_long_term_forecast
154 | dataset: Exchange
155 | data: custom
156 | embed: timeF
157 | root_path: ../dataset/exchange_rate/
158 | data_path: exchange_rate.csv
159 | features: M
160 | seq_len: 432
161 | label_len: 48
162 | pred_len: 0
163 | enc_in: 8
164 | dec_in: 8
165 | c_out: 8
166 |
167 | LTF_ILI_p60:
168 | task_name: pretrain_long_term_forecast
169 | dataset: ILI
170 | data: custom
171 | embed: timeF
172 | root_path: ../dataset/illness/
173 | data_path: national_illness.csv
174 | features: M
175 | seq_len: 96
176 | label_len: 18
177 | pred_len: 0
178 | enc_in: 7
179 | dec_in: 7
180 | c_out: 7
181 |
182 | LTF_Traffic_p96:
183 | task_name: pretrain_long_term_forecast
184 | dataset: Traffic
185 | data: custom
186 | embed: timeF
187 | root_path: ../dataset/traffic/
188 | data_path: traffic.csv
189 | features: M
190 | seq_len: 192
191 | label_len: 48
192 | pred_len: 0
193 | enc_in: 862
194 | dec_in: 862
195 | c_out: 862
196 |
197 | LTF_Traffic_p192:
198 | task_name: pretrain_long_term_forecast
199 | dataset: Traffic
200 | data: custom
201 | embed: timeF
202 | root_path: ../dataset/traffic/
203 | data_path: traffic.csv
204 | features: M
205 | seq_len: 288
206 | label_len: 48
207 | pred_len: 0
208 | enc_in: 862
209 | dec_in: 862
210 | c_out: 862
211 |
212 | LTF_Traffic_p336:
213 | task_name: pretrain_long_term_forecast
214 | dataset: Traffic
215 | data: custom
216 | embed: timeF
217 | root_path: ../dataset/traffic/
218 | data_path: traffic.csv
219 | features: M
220 | seq_len: 432
221 | label_len: 48
222 | pred_len: 0
223 | enc_in: 862
224 | dec_in: 862
225 | c_out: 862
226 |
227 | LTF_Traffic_p720:
228 | task_name: pretrain_long_term_forecast
229 | dataset: Traffic
230 | data: custom
231 | embed: timeF
232 | root_path: ../dataset/traffic/
233 | data_path: traffic.csv
234 | features: M
235 | seq_len: 816
236 | label_len: 48
237 | pred_len: 0
238 | enc_in: 862
239 | dec_in: 862
240 | c_out: 862
241 |
242 | LTF_Weather_p96:
243 | task_name: pretrain_long_term_forecast
244 | dataset: Weather
245 | data: custom
246 | embed: timeF
247 | root_path: ../dataset/weather/
248 | data_path: weather.csv
249 | features: M
250 | seq_len: 192
251 | label_len: 48
252 | pred_len: 0
253 | enc_in: 21
254 | dec_in: 21
255 | c_out: 21
256 |
257 | LTF_Weather_p192:
258 | task_name: pretrain_long_term_forecast
259 | dataset: Weather
260 | data: custom
261 | embed: timeF
262 | root_path: ../dataset/weather/
263 | data_path: weather.csv
264 | features: M
265 | seq_len: 288
266 | label_len: 48
267 | pred_len: 0
268 | enc_in: 21
269 | dec_in: 21
270 | c_out: 21
271 |
272 | LTF_Weather_p336:
273 | task_name: pretrain_long_term_forecast
274 | dataset: Weather
275 | data: custom
276 | embed: timeF
277 | root_path: ../dataset/weather/
278 | data_path: weather.csv
279 | features: M
280 | seq_len: 432
281 | label_len: 48
282 | pred_len: 0
283 | enc_in: 21
284 | dec_in: 21
285 | c_out: 21
286 |
287 | LTF_Weather_p720:
288 | task_name: pretrain_long_term_forecast
289 | dataset: Weather
290 | data: custom
291 | embed: timeF
292 | root_path: ../dataset/weather/
293 | data_path: weather.csv
294 | features: M
295 | seq_len: 816
296 | label_len: 48
297 | pred_len: 0
298 | enc_in: 21
299 | dec_in: 21
300 | c_out: 21
301 |
302 | CLS_Heartbeat:
303 | task_name: pretrain_classification
304 | dataset: Heartbeat
305 | data: UEA
306 | embed: timeF
307 | root_path: ../dataset/Heartbeat/
308 | seq_len: 405
309 | label_len: 0
310 | pred_len: 0
311 | enc_in: 61
312 | num_class: 2
313 | c_out: None
314 |
315 | CLS_JapaneseVowels:
316 | task_name: pretrain_classification
317 | dataset: JapaneseVowels
318 | data: UEA
319 | embed: timeF
320 | root_path: ../dataset/JapaneseVowels/
321 | seq_len: 29
322 | label_len: 0
323 | pred_len: 0
324 | enc_in: 12
325 | num_class: 9
326 | c_out: None
327 |
328 | CLS_PEMS-SF:
329 | task_name: pretrain_classification
330 | dataset: PEMS-SF
331 | data: UEA
332 | embed: timeF
333 | root_path: ../dataset/PEMS-SF/
334 | seq_len: 144
335 | label_len: 0
336 | pred_len: 0
337 | enc_in: 963
338 | num_class: 7
339 | c_out: None
340 |
341 | CLS_SelfRegulationSCP2:
342 | task_name: pretrain_classification
343 | dataset: SelfRegulationSCP2
344 | data: UEA
345 | embed: timeF
346 | root_path: ../dataset/SelfRegulationSCP2/
347 | seq_len: 1152
348 | label_len: 0
349 | pred_len: 0
350 | enc_in: 7
351 | num_class: 2
352 | c_out: None
353 |
354 | CLS_SpokenArabicDigits:
355 | task_name: pretrain_classification
356 | dataset: SpokenArabicDigits
357 | data: UEA
358 | embed: timeF
359 | root_path: ../dataset/SpokenArabicDigits/
360 | seq_len: 93
361 | label_len: 0
362 | pred_len: 0
363 | enc_in: 13
364 | num_class: 10
365 | c_out: None
366 |
367 | CLS_UWaveGestureLibrary:
368 | task_name: pretrain_classification
369 | dataset: UWaveGestureLibrary
370 | data: UEA
371 | embed: timeF
372 | root_path: ../dataset/UWaveGestureLibrary/
373 | seq_len: 315
374 | label_len: 0
375 | pred_len: 0
376 | enc_in: 3
377 | num_class: 8
378 | c_out: None
379 |
380 | CLS_ECG5000:
381 | task_name: pretrain_classification
382 | dataset: ECG5000
383 | data: UEA
384 | embed: timeF
385 | root_path: ../dataset/UCR/ECG5000
386 | seq_len: 140
387 | label_len: 0
388 | pred_len: 0
389 | enc_in: 1
390 | num_class: 5
391 | c_out: None
392 |
393 | CLS_NonInvasiveFetalECGThorax1:
394 | task_name: pretrain_classification
395 | dataset: NonInvasiveFetalECGThorax1
396 | data: UEA
397 | embed: timeF
398 | root_path: ../dataset/UCR/NonInvasiveFetalECGThorax1
399 | seq_len: 750
400 | label_len: 0
401 | pred_len: 0
402 | enc_in: 1
403 | num_class: 52
404 | c_out: None
405 |
406 | CLS_Blink:
407 | task_name: pretrain_classification
408 | dataset: Blink
409 | data: UEA
410 | embed: timeF
411 | root_path: ../dataset/Blink
412 | seq_len: 510
413 | label_len: 0
414 | pred_len: 0
415 | enc_in: 4
416 | num_class: 2
417 | c_out: None
418 |
419 | CLS_FaceDetection:
420 | task_name: pretrain_classification
421 | dataset: FaceDetection
422 | data: UEA
423 | embed: timeF
424 | root_path: ../dataset/FaceDetection
425 | seq_len: 62
426 | label_len: 0
427 | pred_len: 0
428 | enc_in: 144
429 | num_class: 2
430 | c_out: None
431 |
432 | CLS_ElectricDevices:
433 | task_name: pretrain_classification
434 | dataset: ElectricDevices
435 | data: UEA
436 | embed: timeF
437 | root_path: ../dataset/UCR/ElectricDevices
438 | seq_len: 96
439 | label_len: 0
440 | pred_len: 0
441 | enc_in: 1
442 | num_class: 7
443 | c_out: None
444 |
445 | CLS_Trace:
446 | task_name: pretrain_classification
447 | dataset: Trace
448 | data: UEA
449 | embed: timeF
450 | root_path: ../dataset/UCR/Trace
451 | seq_len: 275
452 | label_len: 0
453 | pred_len: 0
454 | enc_in: 1
455 | num_class: 4
456 | c_out: None
457 |
458 | CLS_FordB:
459 | task_name: pretrain_classification
460 | dataset: FordB
461 | data: UEA
462 | embed: timeF
463 | root_path: ../dataset/UCR/FordB
464 | seq_len: 500
465 | label_len: 0
466 | pred_len: 0
467 | enc_in: 1
468 | num_class: 2
469 | c_out: None
470 |
471 | CLS_MotionSenseHAR:
472 | task_name: pretrain_classification
473 | dataset: MotionSenseHAR
474 | data: UEA
475 | embed: timeF
476 | root_path: ../dataset/MotionSenseHAR
477 | seq_len: 200
478 | label_len: 0
479 | pred_len: 0
480 | enc_in: 12
481 | num_class: 6
482 | c_out: None
483 |
484 | CLS_EMOPain:
485 | task_name: pretrain_classification
486 | dataset: EMOPain
487 | data: UEA
488 | embed: timeF
489 | root_path: ../dataset/EMOPain
490 | seq_len: 180
491 | label_len: 0
492 | pred_len: 0
493 | enc_in: 30
494 | num_class: 3
495 | c_out: None
496 |
497 | CLS_Chinatown:
498 | task_name: pretrain_classification
499 | dataset: Chinatown
500 | data: UEA
501 | embed: timeF
502 | root_path: ../dataset/UCR/Chinatown
503 | seq_len: 24
504 | label_len: 0
505 | pred_len: 0
506 | enc_in: 1
507 | num_class: 2
508 | c_out: None
509 |
510 | CLS_MelbournePedestrian:
511 | task_name: pretrain_classification
512 | dataset: MelbournePedestrian
513 | data: UEA
514 | embed: timeF
515 | root_path: ../dataset/UCR/MelbournePedestrian
516 | seq_len: 24
517 | label_len: 0
518 | pred_len: 0
519 | enc_in: 1
520 | num_class: 10
521 | c_out: None
522 |
523 | CLS_SharePriceIncrease:
524 | task_name: pretrain_classification
525 | dataset: SharePriceIncrease
526 | data: UEA
527 | embed: timeF
528 | root_path: ../dataset/SharePriceIncrease
529 | seq_len: 60
530 | label_len: 0
531 | pred_len: 0
532 | enc_in: 1
533 | num_class: 2
534 | c_out: None
--------------------------------------------------------------------------------
/data_provider/multitask_zero_shot_new_length.yaml:
--------------------------------------------------------------------------------
1 | task_dataset:
2 | LTF_ECL_p96:
3 | task_name: long_term_forecast
4 | dataset: ECL
5 | data: custom
6 | embed: timeF
7 | root_path: ../dataset/electricity/
8 | data_path: electricity.csv
9 | features: M
10 | seq_len: 96
11 | label_len: 48
12 | pred_len: 96
13 | enc_in: 321
14 | dec_in: 321
15 | c_out: 321
16 |
17 | LTF_ECL_p192:
18 | task_name: long_term_forecast
19 | dataset: ECL
20 | data: custom
21 | embed: timeF
22 | root_path: ../dataset/electricity/
23 | data_path: electricity.csv
24 | features: M
25 | seq_len: 96
26 | label_len: 48
27 | pred_len: 192
28 | enc_in: 321
29 | dec_in: 321
30 | c_out: 321
31 |
32 | LTF_ECL_p336:
33 | task_name: long_term_forecast
34 | dataset: ECL
35 | data: custom
36 | embed: timeF
37 | root_path: ../dataset/electricity/
38 | data_path: electricity.csv
39 | features: M
40 | seq_len: 96
41 | label_len: 48
42 | pred_len: 336
43 | enc_in: 321
44 | dec_in: 321
45 | c_out: 321
46 |
47 | LTF_ETTh1_p96:
48 | task_name: long_term_forecast
49 | dataset: ETTh1
50 | data: ETTh1
51 | embed: timeF
52 | root_path: ../dataset/ETT-small/
53 | data_path: ETTh1.csv
54 | features: M
55 | seq_len: 96
56 | label_len: 48
57 | pred_len: 96
58 | enc_in: 7
59 | dec_in: 7
60 | c_out: 7
61 |
62 | LTF_ETTh1_p192:
63 | task_name: long_term_forecast
64 | dataset: ETTh1
65 | data: ETTh1
66 | embed: timeF
67 | root_path: ../dataset/ETT-small/
68 | data_path: ETTh1.csv
69 | features: M
70 | seq_len: 96
71 | label_len: 48
72 | pred_len: 192
73 | enc_in: 7
74 | dec_in: 7
75 | c_out: 7
76 |
77 | LTF_ETTh1_p336:
78 | task_name: long_term_forecast
79 | dataset: ETTh1
80 | data: ETTh1
81 | embed: timeF
82 | root_path: ../dataset/ETT-small/
83 | data_path: ETTh1.csv
84 | features: M
85 | seq_len: 96
86 | label_len: 48
87 | pred_len: 336
88 | enc_in: 7
89 | dec_in: 7
90 | c_out: 7
91 |
92 | LTF_Exchange_p192:
93 | task_name: long_term_forecast
94 | dataset: Exchange
95 | data: custom
96 | embed: timeF
97 | root_path: ../dataset/exchange_rate/
98 | data_path: exchange_rate.csv
99 | features: M
100 | seq_len: 96
101 | label_len: 48
102 | pred_len: 192
103 | enc_in: 8
104 | dec_in: 8
105 | c_out: 8
106 |
107 | LTF_Exchange_p336:
108 | task_name: long_term_forecast
109 | dataset: Exchange
110 | data: custom
111 | embed: timeF
112 | root_path: ../dataset/exchange_rate/
113 | data_path: exchange_rate.csv
114 | features: M
115 | seq_len: 96
116 | label_len: 48
117 | pred_len: 336
118 | enc_in: 8
119 | dec_in: 8
120 | c_out: 8
121 |
122 | LTF_Traffic_p96:
123 | task_name: long_term_forecast
124 | dataset: Traffic
125 | data: custom
126 | embed: timeF
127 | root_path: ../dataset/traffic/
128 | data_path: traffic.csv
129 | features: M
130 | seq_len: 96
131 | label_len: 48
132 | pred_len: 96
133 | enc_in: 862
134 | dec_in: 862
135 | c_out: 862
136 |
137 | LTF_Traffic_p192:
138 | task_name: long_term_forecast
139 | dataset: Traffic
140 | data: custom
141 | embed: timeF
142 | root_path: ../dataset/traffic/
143 | data_path: traffic.csv
144 | features: M
145 | seq_len: 96
146 | label_len: 48
147 | pred_len: 192
148 | enc_in: 862
149 | dec_in: 862
150 | c_out: 862
151 |
152 | LTF_Traffic_p336:
153 | task_name: long_term_forecast
154 | dataset: Traffic
155 | data: custom
156 | embed: timeF
157 | root_path: ../dataset/traffic/
158 | data_path: traffic.csv
159 | features: M
160 | seq_len: 96
161 | label_len: 48
162 | pred_len: 336
163 | enc_in: 862
164 | dec_in: 862
165 | c_out: 862
166 |
167 | LTF_Weather_p96:
168 | task_name: long_term_forecast
169 | dataset: Weather
170 | data: custom
171 | embed: timeF
172 | root_path: ../dataset/weather/
173 | data_path: weather.csv
174 | features: M
175 | seq_len: 96
176 | label_len: 48
177 | pred_len: 96
178 | enc_in: 21
179 | dec_in: 21
180 | c_out: 21
181 |
182 | LTF_Weather_p192:
183 | task_name: long_term_forecast
184 | dataset: Weather
185 | data: custom
186 | embed: timeF
187 | root_path: ../dataset/weather/
188 | data_path: weather.csv
189 | features: M
190 | seq_len: 96
191 | label_len: 48
192 | pred_len: 192
193 | enc_in: 21
194 | dec_in: 21
195 | c_out: 21
196 |
197 | LTF_Weather_p336:
198 | task_name: long_term_forecast
199 | dataset: Weather
200 | data: custom
201 | embed: timeF
202 | root_path: ../dataset/weather/
203 | data_path: weather.csv
204 | features: M
205 | seq_len: 96
206 | label_len: 48
207 | pred_len: 336
208 | enc_in: 21
209 | dec_in: 21
210 | c_out: 21
--------------------------------------------------------------------------------
/data_provider/uea.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 |
6 |
def collate_fn(data, max_len=None):
    """Build mini-batch tensors from a list of (X, y) tuples, padding/clipping X to a common length.

    Args:
        data: len(batch_size) list of tuples (X, y).
            - X: torch tensor of shape (seq_length, feat_dim); variable seq_length.
            - y: torch tensor of shape (num_labels,) : class indices or numerical targets
                (for classification or regression, respectively). num_labels > 1 for multi-task models
        max_len: global fixed sequence length. Used for architectures requiring fixed length input,
            where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are
            padded with 0s. When None, the longest sequence in the batch defines the padded length.
    Returns:
        X: (batch_size, padded_length, feat_dim) torch tensor of zero-padded features (input)
        targets: (batch_size, num_labels) torch tensor of stacked labels
        padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position,
            0 means padding
    """

    batch_size = len(data)
    features, labels = zip(*data)

    # Original (unpadded) sequence length of every time series in the batch.
    lengths = [X.shape[0] for X in features]
    if max_len is None:
        max_len = max(lengths)

    # Stack and pad features (convert 2D to 3D tensors, i.e. add batch dimension).
    X = torch.zeros(batch_size, max_len, features[0].shape[-1])  # (batch_size, padded_length, feat_dim)
    for i, (sample, length) in enumerate(zip(features, lengths)):
        end = min(length, max_len)  # clip sequences longer than max_len
        X[i, :end, :] = sample[:end, :]

    targets = torch.stack(labels, dim=0)  # (batch_size, num_labels)

    # torch.long instead of int16: int16 cannot represent sequence lengths above 32767.
    padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.long),
                                 max_len=max_len)  # (batch_size, padded_length) boolean tensor, "1" means keep

    return X, targets, padding_masks
43 |
44 |
def padding_mask(lengths, max_len=None):
    """
    Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths,
    where 1 means keep element at this position (time step)

    Args:
        lengths: tensor of sequence lengths, one entry per batch element.
        max_len: number of mask columns. When None (or 0), the largest value in `lengths` is used.
    Returns:
        Boolean tensor with one row per entry of `lengths`; entry [i, t] is True iff t < lengths[i].
    """
    batch_size = lengths.numel()
    if not max_len:
        # Tensors expose `.max()`; the previous `.max_val()` call does not exist and raised AttributeError.
        max_len = int(lengths.max())
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    # Compare every time-step index against each sequence's own length.
    return positions.repeat(batch_size, 1).lt(lengths.unsqueeze(1))
56 |
57 |
class Normalizer(object):
    """
    Normalizes a dataframe, either across ALL contained rows (time steps) or separately per sample,
    depending on the selected `norm_type`.
    """

    def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None):
        """
        Args:
            norm_type: choose from:
                "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps)
                "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. across only its own rows)
            mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values
        """

        self.norm_type = norm_type
        self.mean = mean
        self.std = std
        self.min_val = min_val
        self.max_val = max_val

    def normalize(self, df):
        """
        Args:
            df: input dataframe
        Returns:
            df: normalized dataframe
        Raises:
            NameError: if `norm_type` is not one of the supported methods
        """
        # Small constant added to every denominator so zero-variance / zero-range features do not divide by zero.
        eps = np.finfo(float).eps

        if self.norm_type == "standardization":
            if self.mean is None:
                # Statistics are computed once and cached on the instance for later calls.
                self.mean = df.mean()
                self.std = df.std()
            return (df - self.mean) / (self.std + eps)

        elif self.norm_type == "minmax":
            if self.max_val is None:
                self.max_val = df.max()
                self.min_val = df.min()
            return (df - self.min_val) / (self.max_val - self.min_val + eps)

        elif self.norm_type == "per_sample_std":
            # Rows sharing the same index value are treated as one sample.
            grouped = df.groupby(by=df.index)
            # eps added for consistency with the other modes: a constant sample has std 0 and would yield 0/0 = NaN.
            return (df - grouped.transform('mean')) / (grouped.transform('std') + eps)

        elif self.norm_type == "per_sample_minmax":
            grouped = df.groupby(by=df.index)
            min_vals = grouped.transform('min')
            return (df - min_vals) / (grouped.transform('max') - min_vals + eps)

        else:
            raise NameError(f'Normalize method "{self.norm_type}" not implemented')
108 |
109 |
def interpolate_missing(y):
    """
    Fills the NaN entries of pd.Series `y` by linear interpolation (applied in both directions);
    a Series without NaN values is handed back unchanged.
    """
    has_gaps = y.isna().any()
    return y.interpolate(method='linear', limit_direction='both') if has_gaps else y
117 |
118 |
def subsample(y, limit=256, factor=2):
    """
    Keeps every `factor`-th element (with a fresh 0..n-1 index) when the Series has more than `limit` entries;
    shorter Series are returned as they are.
    """
    if len(y) <= limit:
        return y
    thinned = y[::factor]
    return thinned.reset_index(drop=True)
126 |
--------------------------------------------------------------------------------
/data_provider/zeroshot_task.yaml:
--------------------------------------------------------------------------------
1 | task_dataset:
2 | Solar:
3 | task_name: long_term_forecast
4 | dataset_name: solar_10_minutes
5 | dataset: Solar
6 | data: gluonts
7 | root_path: ../dataset/gluonts
8 | seq_len: 128
9 | label_len: 0
10 | pred_len: 64
11 | features: M
12 | embed: timeF
13 | enc_in: 137
14 | dec_in: 137
15 | c_out: 137
16 |
17 | Saugeen_River:
18 | task_name: long_term_forecast
19 | dataset_name: saugeenday
20 | dataset: Saugeen_River
21 | data: gluonts
22 | root_path: ../dataset/gluonts
23 | seq_len: 256
24 | label_len: 0
25 | pred_len: 128
26 | features: M
27 | embed: timeF
28 | enc_in: 1
29 | dec_in: 1
30 | c_out: 1
31 |
32 | Hospital:
33 | task_name: long_term_forecast
34 | dataset_name: hospital
35 | dataset: Hospital
36 | data: gluonts
37 | root_path: ../dataset/gluonts
38 | seq_len: 32
39 | label_len: 0
40 | pred_len: 16
41 | features: M
42 | embed: timeF
43 | enc_in: 767
44 | dec_in: 767
45 | c_out: 767
46 |
47 | Web_Traffic:
48 | task_name: long_term_forecast
49 | dataset_name: kaggle_web_traffic_without_missing
50 | dataset: Web_Traffic
51 | data: gluonts
52 | root_path: ../dataset/gluonts
53 | seq_len: 160
54 | label_len: 0
55 | pred_len: 80
56 | features: M
57 | embed: timeF
58 | enc_in: 500
59 | dec_in: 500
60 | c_out: 500
--------------------------------------------------------------------------------
/download_data_all.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Downloads all datasets into ./dataset.
#
# Each archive is fetched only when its zip file is not already present, so the
# script can be re-run to pick up datasets that failed earlier. A failed
# download removes the partial zip; otherwise the "zip exists" guard would skip
# that dataset forever.

mkdir -p dataset

# check for gdown https://github.com/wkentaro/gdown then install if necessary
if ! command -v gdown &> /dev/null
then
    echo "installing gdown, for downloading from google drive"
    pip install gdown
fi

TSC_BASE_URL="https://www.timeseriesclassification.com/aeon-toolkit"

# download URL DEST: fetch URL into DEST with wget; on failure delete the
# partial DEST and return non-zero so the caller skips the extraction steps.
download() {
    local url="$1"
    local dest="$2"
    if ! wget "$url" -O "$dest"; then
        echo "download failed: $url" >&2
        rm -f "$dest"
        return 1
    fi
}

# fetch_tsc_dataset NAME: download NAME.zip from timeseriesclassification.com
# and extract it into dataset/NAME/ (skipped when dataset/NAME.zip exists).
fetch_tsc_dataset() {
    local name="$1"
    local archive="dataset/${name}.zip"
    if [ ! -f "$archive" ]; then
        download "${TSC_BASE_URL}/${name}.zip" "$archive" \
            && unzip "$archive" -d "dataset/${name}"
    fi
}

# TimesNet data
# downloads all_datasets.zip and extracts into dataset/
if [ ! -f dataset/all_datasets.zip ]; then
    gdown "https://drive.google.com/file/d/1pmXvqWsfUeXWCMz5fqsP8WLKXR5jxY8z/view?usp=drive_link" --fuzzy -O dataset/all_datasets.zip \
        && unzip dataset/all_datasets.zip -d dataset/ \
        && mv dataset/all_datasets/* dataset/ \
        && rm -rf dataset/all_datasets
fi

# UEA data
# downloads Multivariate2018_ts.zip then extracts into dataset/UAE/
# (the directory keeps its existing "UAE" spelling; other files may reference that path)
if [ ! -f dataset/Multivariate2018_ts.zip ]; then
    download "${TSC_BASE_URL}/Archives/Multivariate2018_ts.zip" dataset/Multivariate2018_ts.zip \
        && unzip dataset/Multivariate2018_ts.zip -d dataset/ \
        && mv dataset/Multivariate_ts dataset/UAE
fi

# UCR data
# downloads Univariate2018_ts.zip then extracts into dataset/UCR/
if [ ! -f dataset/Univariate2018_ts.zip ]; then
    download "${TSC_BASE_URL}/Archives/Univariate2018_ts.zip" dataset/Univariate2018_ts.zip \
        && unzip dataset/Univariate2018_ts.zip -d dataset/ \
        && mv dataset/Univariate_ts dataset/UCR
fi


# Other timeseriesclassification.com datasets, each extracted into dataset/<name>/
TSC_DATASETS=(
    Blink
    MotionSenseHAR
    EMOPain
    SharePriceIncrease
    AbnormalHeartbeat
    AsphaltObstacles
    AsphaltObstaclesCoordinates
    AsphaltPavementType
    AsphaltPavementTypeCoordinates
    AsphaltRegularity
    AsphaltRegularityCoordinates
    BinaryHeartbeat
    CatsDogs
    Colposcopy
    CounterMovementJump
    DucksAndGeese
    ElectricDeviceDetection
    Epilepsy2
    EyesOpenShut
    FaultDetectionA
    FruitFlies
    InsectSound
    KeplerLightCurves
    MindReading
    MosquitoSound
    NerveDamage
    RightWhaleCalls
    Sleep
    Tiselac
    UrbanSound
    WalkingSittingStanding
)

# A failure for one dataset is reported by download() and does not stop the rest.
for name in "${TSC_DATASETS[@]}"; do
    fetch_tsc_dataset "$name"
done
227 |
--------------------------------------------------------------------------------
/exp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mims-harvard/UniTS/0e0281482864017cac8832b2651906ff5375a34e/exp/__init__.py
--------------------------------------------------------------------------------
/exp/exp_pretrain.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_factory import data_provider
2 | from utils.tools import cosine_scheduler
3 | from utils.tools import NativeScalerWithGradNormCount as NativeScaler
4 | from utils.losses import UnifiedMaskRecLoss
5 | from utils.dataloader import BalancedDataLoaderIterator
6 | from utils.ddp import is_main_process, get_world_size
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch import optim
11 | import torch.distributed as dist
12 |
13 | import os
14 | import time
15 | import warnings
16 | import numpy as np
17 | import yaml
18 | import wandb
19 | import importlib
20 | import sys
21 |
22 | warnings.filterwarnings('ignore')
23 |
def custom_print_decorator(func):
    """
    Build a print-like callable that can additionally append its output to a log file.

    The returned wrapper accepts the positional values to print plus these keyword arguments:
        file:   stream to write to; sys.stdout when missing or None
        folder: when truthy, the same text is appended to `<folder>/finetune_output.log`
        sep:    separator placed between values (default ' ')
        end:    text appended after the last value (default a newline)

    Args:
        func: the function being replaced (normally the builtin `print`). It is not called;
            the wrapper writes to the stream itself and only copies `func`'s metadata.
    Returns:
        the wrapper, which returns None like `print`
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        sep = kwargs.get('sep')
        end = kwargs.get('end')
        # Mirror print(): None means "use the default".
        sep = ' ' if sep is None else sep
        end = '\n' if end is None else end
        text = sep.join(map(str, args)) + end

        stream = kwargs.get('file')
        if stream is None:
            # Looked up at call time so a redirected sys.stdout is honoured.
            stream = sys.stdout
        stream.write(text)

        folder = kwargs.get('folder')
        if folder:
            with open(f'{folder}/finetune_output.log', 'a') as log_file:
                log_file.write(text)

    return wrapper


# replace print to save all print into log files
print = custom_print_decorator(print)
44 |
45 |
def read_task_data_config(config_path):
    """
    Load the `task_dataset` section of a YAML task configuration file.

    Args:
        config_path: path of the YAML file to read
    Returns:
        the mapping stored under the top-level `task_dataset` key, or an empty dict when the
        file is empty or that key is missing / has no value
    """
    with open(config_path, 'r') as config_file:
        # NOTE(review): FullLoader can construct more than plain data types; only point this at trusted config files.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    # An empty document loads as None, and a bare `task_dataset:` key maps to None as well.
    if not config:
        return {}
    return config.get('task_dataset') or {}
51 |
52 |
def get_task_data_config_list(task_data_config, default_batch_size=None):
    """
    Turn the task mapping into a list of `[task_name, task_config]` pairs.

    Every task config dict is updated in place with `max_batch = default_batch_size`;
    the returned pairs reference those same dict objects, in the mapping's iteration order.
    """
    for settings in task_data_config.values():
        settings['max_batch'] = default_batch_size
    return [[name, settings] for name, settings in task_data_config.items()]
61 |
62 |
def init_and_merge_datasets(data_loader_list):
    """
    Wrap the given data loaders in a single BalancedDataLoaderIterator.

    Returns:
        (iterator, steps): the merged iterator and the value reported by its `__len__`
    """
    merged_loader = BalancedDataLoaderIterator(data_loader_list)
    return merged_loader, merged_loader.__len__()
68 |
69 |
class Exp_All_Task(object):
    """
    Pre-training experiment over all tasks listed in the task/dataset YAML config.

    The constructor reads the config and resolves this process' CUDA device from its
    distributed rank; `train` runs the full pre-training loop and writes a checkpoint
    after every epoch.
    """

    def __init__(self, args):
        """
        Args:
            args: run configuration; this class reads (among others) `task_data_config_path`,
                `batch_size`, `model`, `checkpoints`, `train_epochs`, `acc_it` and `patch_len` from it
        """
        super(Exp_All_Task, self).__init__()

        self.args = args
        self.task_data_config = read_task_data_config(
            self.args.task_data_config_path)
        self.task_data_config_list = get_task_data_config_list(
            self.task_data_config, default_batch_size=self.args.batch_size)
        device_id = dist.get_rank() % torch.cuda.device_count()
        print("this device_id:", device_id)
        self.device_id = device_id

    def _build_model(self, ddp=True):
        """
        Instantiate `models.<args.model>.Model` in pre-training mode on this process' device.

        Args:
            ddp: wrap the model in DistributedDataParallel when True
        Returns:
            the (optionally wrapped) model
        """
        module = importlib.import_module("models."+self.args.model)
        model = module.Model(
            self.args, self.task_data_config_list, pretrain=True).to(self.device_id)
        if ddp:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[self.device_id], find_unused_parameters=True)
        return model.to(self.device_id)

    def _get_data(self, flag):
        """
        Build one dataset and one data loader per configured task.

        Reads `self.path` for logging, so `train` must have set it before this is called.

        Args:
            flag: split name forwarded to `data_provider` (e.g. 'train', 'val')
        Returns:
            (data_set_list, data_loader_list), both in config order
        """
        data_set_list = []
        data_loader_list = []
        for task_data_name, task_config in self.task_data_config.items():
            print("loading dataset:", task_data_name, folder=self.path)
            # Resolve the split per task: remapping one UEA task to 'test' must not
            # change the split requested for the tasks that come after it.
            task_flag = flag
            if task_config['data'] == 'UEA' and flag == 'val':
                # TODO strange that no val set is used for classification. Set to test set for val
                task_flag = 'test'
            data_set, data_loader = data_provider(
                self.args, task_config, task_flag, ddp=True)
            data_set_list.append(data_set)
            data_loader_list.append(data_loader)
        return data_set_list, data_loader_list

    def _select_optimizer(self):
        """
        Create the Adam optimizer with a learning rate scaled by the effective batch size.

        The effective batch size is `batch_size * acc_it * world_size`, and the learning rate
        actually used is `learning_rate * effective_batch_size / 32`; it is also stored on
        `self.real_learning_rate`.

        Returns:
            the optimizer over `self.model.parameters()`
        """
        eff_batch_size = self.args.batch_size * self.args.acc_it * get_world_size()
        real_learning_rate = self.args.learning_rate * eff_batch_size / 32
        # The base lr is the configured value, i.e. real_learning_rate * 32 / eff_batch_size.
        print("base lr: %.2e" % self.args.learning_rate)
        print("actual lr: %.2e" % real_learning_rate)
        self.real_learning_rate = real_learning_rate

        print("accumulate grad iterations: %d" % self.args.acc_it)
        print("effective batch size: %d" % eff_batch_size)
        model_optim = optim.Adam(
            self.model.parameters(),
            lr=real_learning_rate,
            betas=(0.9, self.args.beta2),
            weight_decay=self.args.weight_decay,
            eps=self.args.eps)
        return model_optim

    def train(self, setting):
        """
        Run pre-training for `args.train_epochs` epochs.

        Args:
            setting: run name; checkpoints and logs go to `<args.checkpoints>/<setting>`
        Returns:
            the trained model (`self.model`)
        """
        path = os.path.join(self.args.checkpoints, setting)
        if is_main_process():
            # exist_ok avoids failing when the directory appears between check and creation.
            os.makedirs(path, exist_ok=True)
        self.path = path

        torch.cuda.synchronize()
        dist.barrier()

        # Data loader
        _, train_loader_list = self._get_data(flag='train')
        data_loader_cycle, train_steps = init_and_merge_datasets(
            train_loader_list)

        # Set up batch size for each task
        if self.args.memory_check:
            self.memory_check(data_loader_cycle)
            torch.cuda.empty_cache()

        torch.cuda.synchronize()
        dist.barrier()

        # Model
        self.model = self._build_model()

        pytorch_total_params = sum(p.numel() for p in self.model.parameters())
        print("Parameters number {} M".format(
            pytorch_total_params/1e6), folder=self.path)
        print("{} steps for each epoch".format(train_steps), folder=self.path)

        # Optimizer
        model_optim = self._select_optimizer()
        # Indexed by global iteration (train_steps * epoch + i) in train_one_epoch.
        lr_schedule = cosine_scheduler(
            self.real_learning_rate,
            self.args.min_lr,
            self.args.train_epochs, train_steps,
            warmup_epochs=self.args.warmup_epochs,
        )

        # Loss
        criterion = UnifiedMaskRecLoss().to(self.device_id)
        scaler = NativeScaler()

        for epoch in range(self.args.train_epochs):
            train_loss = self.train_one_epoch(
                model_optim, data_loader_cycle, criterion, epoch, train_steps, scaler, lr_schedule)

            print("Epoch: {0}, Steps: {1} | Avg Train Loss: {2:.7f}".format(
                epoch + 1, train_steps, train_loss), folder=self.path)
            if is_main_process():
                wandb.log({'train_loss_avg': train_loss})

            if is_main_process():
                # The same file is overwritten every epoch, so it always holds the latest state.
                save_dict = {
                    'student': self.model.state_dict(),
                    'optimizer': model_optim.state_dict(),
                    'epoch': epoch + 1,
                    'args': self.args,
                }

                torch.save(save_dict, os.path.join(
                    path, 'pretrain_checkpoint.pth'))

        return self.model

    def train_one_epoch(self, model_optim, data_loader_cycle, criterion, epoch, train_steps, scaler, lr_schedule):
        """
        Train for one pass over `data_loader_cycle`.

        Each loader batch is split into sub-batches no larger than the task's `max_batch`;
        gradients are accumulated over those sub-batches and over `args.acc_it` loader
        batches before the optimizer is stepped.

        Returns:
            the mean of the per-batch displayed losses
        """
        current_device = torch.cuda.current_device()
        train_loss_set = []

        acc_it = self.args.acc_it
        max_norm = self.args.clip_grad
        min_keep_ratio = self.args.min_keep_ratio

        self.model.train()
        epoch_time = time.time()
        self.model.zero_grad(set_to_none=True)
        loss_sum_display = 0

        for i, (sample_init, task_id) in enumerate(data_loader_cycle):
            it = train_steps * epoch + i
            for param_group in model_optim.param_groups:
                param_group["lr"] = lr_schedule[it]

            # Get batch data based on the real batch size of each task: avoid OOM for large samples
            task_name = self.task_data_config_list[task_id][1]['task_name']
            small_batch_size = self.task_data_config_list[task_id][1]['max_batch']
            sample_list = self.get_multi_source_data(
                sample_init, task_name, small_batch_size, min_keep_ratio=min_keep_ratio)
            len_sample_list = len(sample_list)

            # Accumulate gradients of multiple samples
            for sample_idx, sample in enumerate(sample_list):
                x_enc, x_mark_enc, pad_mask = sample
                with torch.cuda.amp.autocast():
                    model_output = self.model(
                        x_enc=x_enc, x_mark_enc=x_mark_enc, task_id=task_id, task_name=task_name, enable_mask=True)
                loss_dict = criterion(model_output, x_enc, pad_mask)
                loss = loss_dict['loss']
                loss /= acc_it
                loss /= len_sample_list
                # The last sub-batch is handled after the loop, where the optimizer may step.
                if sample_idx < len_sample_list-1:
                    norm_value = scaler(loss, model_optim, clip_grad=max_norm,
                                        parameters=self.model.parameters(), create_graph=False, update_grad=False)

            # Undo both scalings so the displayed value is the raw loss of the last sub-batch.
            loss_display = loss.item()*len_sample_list*acc_it
            train_loss_set.append(loss_display)

            norm_value = scaler(loss, model_optim, clip_grad=max_norm,
                                parameters=self.model.parameters(), create_graph=False, update_grad=((i + 1) % acc_it == 0))

            if (i+1) % acc_it == 0:
                model_optim.zero_grad()
            torch.cuda.synchronize()

            loss_sum_display += loss_display

            # release memory to avoid OOM
            del sample_init
            del sample_list
            if torch.cuda.memory_reserved(current_device) > 30*1e9:
                torch.cuda.empty_cache()

            if is_main_process():
                wandb_loss_dict = {
                    'norm': norm_value if norm_value is not None else 0,
                    'train_cls_loss_'+self.task_data_config_list[task_id][0]: loss_dict['cls_loss'].item(),
                    'train_mask_loss_'+self.task_data_config_list[task_id][0]: loss_dict['mask_loss'].item(),
                    'train_sum_loss_'+self.task_data_config_list[task_id][0]: loss_dict['loss'].item(),
                    "loss_avg": loss_sum_display/(i+1)
                }
                wandb.log(wandb_loss_dict)

            if (i + 1) % 50 == 0 and is_main_process():
                # current_loss uses loss_display so it stays comparable with loss_avg
                # regardless of how many sub-batches the batch was split into.
                print("\titers: {0}, epoch: {1} | lr: {2:.5} | loss_avg: {3} | current_loss: {4} |current data: {5}".format(
                    i + 1, epoch + 1, lr_schedule[it], loss_sum_display/(i+1), loss_display, task_name), folder=self.path)

        if is_main_process():
            print("Epoch: {} cost time: {}".format(
                epoch + 1, time.time() - epoch_time), folder=self.path)
        train_loss = np.average(train_loss_set)

        return train_loss

    def get_multi_source_data(self, this_batch, task_name, small_batch_size, min_keep_ratio=None):
        """
        Splits the input batch into smaller batches based on the specified small_batch_size.

        When `min_keep_ratio` is given, a single random keep ratio in
        [min_keep_ratio, 1) is drawn for the whole batch and dimension 1 of the data,
        marks and padding mask is cut to that fraction, rounded up to a multiple of
        `args.patch_len`.

        Args:
            this_batch (tuple): The input batch containing all data of a task.
            task_name (str): The name of the task; must contain "long_term_forecast" or "classification".
            small_batch_size (int): The size of the smaller batches to split the data into.
            min_keep_ratio (float, optional): The minimum ratio of data to keep in each smaller batch.

        Returns:
            list: A list of tuples, where each tuple contains a smaller batch of data, marks, and padding masks.

        Raises:
            ValueError: If `task_name` matches neither supported task type.
        """

        def split_tensor(tensor, size):
            return [tensor[i:min(i + size, tensor.size(0))] for i in range(0, tensor.size(0), size)]

        if "long_term_forecast" in task_name:
            batch_x, _, batch_x_mark, _ = this_batch
            batch_x = batch_x.float().to(self.device_id)
            batch_x_mark = batch_x_mark.float().to(self.device_id)
            # Collapse the last dimension of the marks to its maximum.
            batch_x_mark = batch_x_mark.max(dim=-1)[0]
            padding_mask = torch.ones(
                (batch_x.shape[0], batch_x.shape[1]), dtype=torch.bool).to(self.device_id)
        elif "classification" in task_name:
            batch_x, _, padding_mask = this_batch
            batch_x = batch_x.float().to(self.device_id)
            # The padding mask doubles as the mark input for classification batches.
            batch_x_mark = padding_mask.float().to(self.device_id)
            padding_mask = batch_x_mark.bool().to(self.device_id)
        else:
            raise ValueError(
                f'Unsupported task_name "{task_name}" for pre-training data')

        if min_keep_ratio is not None:
            keep_ratios = torch.rand(
                1, device=batch_x.device) * (1.0 - min_keep_ratio) + min_keep_ratio
            L = batch_x.shape[1]
            len_keeps = (L * keep_ratios).long()
            len_keeps = (torch.ceil(len_keeps/self.args.patch_len)
                         )*self.args.patch_len
            len_keeps = len_keeps.int()

            batch_x = batch_x[:, :len_keeps]
            batch_x_mark = batch_x_mark[:, :len_keeps]
            padding_mask = padding_mask[:, :len_keeps]

        split_batch_x = split_tensor(batch_x, small_batch_size)
        split_batch_x_mark = split_tensor(batch_x_mark, small_batch_size)
        split_padding_mask = split_tensor(padding_mask, small_batch_size)

        return list(zip(split_batch_x, split_batch_x_mark, split_padding_mask))

    def memory_check(self, data_loader_cycle, holdout_memory=6):
        """
        Checks the memory usage of the model by gradually increasing the batch size until it reaches the maximum batch size that can be supported without running out of memory.

        For every data loader the batch size is doubled, starting at 1, until either a
        forward/backward pass fails or the default batch size is reached; the largest
        size that succeeded is stored as the task's `max_batch`.

        Args:
            data_loader_cycle (DataLoaderCycle): The data loader cycle object.
            holdout_memory (int): The amount of memory (in GB) to hold out for other operations.

        Returns:
            None
        """
        # Reserve `holdout_memory` GB (float32 = 4 bytes per element) while probing.
        num_elements = holdout_memory * 1024 * 1024 * 1024 // 4
        extra_mem = torch.empty(
            num_elements, dtype=torch.float32, device=self.device_id)

        model_tmp = self._build_model(ddp=False)
        criterion = UnifiedMaskRecLoss().to(self.device_id)
        model_tmp.train()
        model_tmp.zero_grad(set_to_none=True)

        for data_loader_id in range(data_loader_cycle.num_dataloaders):
            batch_size = 1
            max_batch_size = 0
            torch.cuda.synchronize()
            model_tmp.zero_grad(set_to_none=True)
            while True:
                try:
                    sample, task_id = data_loader_cycle.generate_fake_samples_for_batch(
                        data_loader_id, batch_size)
                    task_name = self.task_data_config_list[task_id][1]['task_name']
                    if "long_term_forecast" in task_name:
                        batch_x, _, batch_x_mark, _ = sample
                        batch_x = batch_x.float().to(self.device_id)
                        batch_x_mark = batch_x_mark.float().to(self.device_id)
                    elif "classification" in task_name:
                        batch_x, _, batch_x_mark = sample
                        batch_x = batch_x.float().to(self.device_id)
                        batch_x_mark = torch.ones(
                            (batch_x.shape[0], batch_x.shape[1]), dtype=torch.bool).to(self.device_id)

                    print(task_id, task_name,
                          sample[0].shape, "max batch size", max_batch_size)
                    with torch.cuda.amp.autocast():
                        model_output = model_tmp(
                            x_enc=batch_x, x_mark_enc=batch_x_mark, task_id=task_id, task_name=task_name, enable_mask=True)
                    # Any scalar depending on all outputs is enough to drive a backward pass.
                    loss = 0.0
                    for each in model_output:
                        if each is not None:
                            loss += each.sum()

                    loss.backward()
                    max_batch_size = batch_size
                    batch_size *= 2

                    if max_batch_size >= self.args.batch_size:
                        print("can support default batchsize:",
                              self.args.batch_size, max_batch_size)
                        self.task_data_config_list[task_id][1]['max_batch'] = max_batch_size
                        self.task_data_config_list[task_id][1]['checkpointing'] = False
                        break

                except Exception as e:
                    # Deliberately broad: any failure at this batch size (presumably CUDA
                    # out-of-memory) ends the search for this task.
                    print(task_id, "max batch size:", max_batch_size)
                    self.task_data_config_list[task_id][1]['max_batch'] = max_batch_size
                    print(f"An exception occurred: {e}")
                    # Rebuild the probe model so the next task starts from a clean state.
                    del model_tmp
                    del criterion
                    torch.cuda.empty_cache()
                    model_tmp = self._build_model(ddp=False)
                    criterion = UnifiedMaskRecLoss().to(self.device_id)
                    break
        del extra_mem
        del model_tmp
        del criterion
        torch.cuda.empty_cache()
        print(self.task_data_config_list)
        return
389 |
--------------------------------------------------------------------------------
/models/UniTS_zeroshot.py:
--------------------------------------------------------------------------------
1 | """
2 | UniTS
3 | """
4 | import math
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 | from timm.layers import Mlp, DropPath
10 | from timm.layers.helpers import to_2tuple
11 |
12 |
def calculate_unfold_output_length(input_length, size, step):
    """Return how many windows of length `size`, taken every `step` positions, fit in `input_length`."""
    return (input_length - size) // step + 1
17 |
18 |
class CrossAttention(nn.Module):
    """
    Multi-head cross attention whose keys/values come from `x` and whose queries come
    either from an explicit `query` tensor or from a learned template.

    Args:
        dim: feature size of the inputs and outputs; must be divisible by `num_heads`
        num_heads: number of attention heads
        qkv_bias: add a bias to the q and kv projections
        qk_norm: apply `norm_layer` to the per-head queries and keys
        attn_drop: attention dropout probability (used only in training mode)
        proj_drop: dropout probability after the output projection
        norm_layer: normalization layer class used when `qk_norm` is True
        var_num: number of learned query rows; when None no template is created and
            `forward` must always be given a `query`
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
        var_num=None,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        if var_num is not None:
            self.template = nn.Parameter(
                torch.zeros(var_num, dim), requires_grad=True)
            torch.nn.init.normal_(self.template, std=.02)
        else:
            # Registered as None so the attribute exists without adding a state_dict entry.
            self.register_parameter('template', None)
        self.var_num = var_num

    def forward(self, x, query=None):
        """
        Args:
            x: (B, N, C) tensor providing keys and values
            query: optional (B, M, C) tensor of queries; the learned template is used when omitted
        Returns:
            (B, M, C) tensor, where M is `query.shape[1]` or `var_num`
        Raises:
            ValueError: if `query` is omitted and the module was built without `var_num`
        """
        B, N, C = x.shape
        if query is not None:
            q = self.q(query).reshape(
                B, query.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            q = self.q_norm(q)
            var_num = query.shape[1]
        else:
            if self.template is None:
                raise ValueError(
                    'CrossAttention needs `query` when it was created with var_num=None')
            q = self.q(self.template).reshape(1, self.var_num,
                                              self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            q = self.q_norm(q)
            # The template is shared by every sample in the batch.
            q = q.repeat(B, 1, 1, 1)
            var_num = self.var_num
        kv = self.kv(x).reshape(B, N, 2, self.num_heads,
                                self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        k = self.k_norm(k)

        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

        x = x.transpose(1, 2).reshape(B, var_num, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
77 |
78 |
class DynamicLinear(nn.Module):
    """
    A dynamic linear layer that can interpolate the weight size to support any given input and output feature dimension.

    The first `fixed_in` input columns of the weight keep their width and are only
    resized along the output dimension; the remaining columns are resized along both.

    Args:
        in_features: number of input features the stored weight is built for
        out_features: number of output features the stored weight is built for
        fixed_in: number of leading input features that are never resized in width;
            must be smaller than `in_features`
        bias: accepted for signature compatibility; a bias parameter is always created
    """

    def __init__(self, in_features=None, out_features=None, fixed_in=0, bias=True):
        super(DynamicLinear, self).__init__()
        if in_features is None or out_features is None:
            raise ValueError("in_features and out_features are required !!!")
        if not fixed_in < in_features:
            raise ValueError("fixed_in < in_features is required !!!")
        self.in_features = in_features
        self.out_features = out_features
        self.weights = nn.Parameter(torch.Tensor(out_features, in_features))
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.fixed_in = fixed_in

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias with the same scheme the code applies here: Kaiming-uniform weight, fan-in bounded bias."""
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, out_features):
        """
        Forward pass for the dynamic linear layer.

        Args:
            x: input whose last dimension is the (possibly different) number of input features
            out_features: number of output features to produce
        Returns:
            tensor with the same leading dimensions as `x` and `out_features` as last dimension
        """
        fixed_weights = self.weights[:, :self.fixed_in]
        dynamic_weights = self.weights[:, self.fixed_in:]
        this_bias = self.bias
        in_features = x.shape[-1]

        if in_features != self.weights.size(1) or out_features != self.weights.size(0):
            # interpolate() expects (N, C, H, W), hence the two temporary leading dims.
            dynamic_weights = F.interpolate(dynamic_weights.unsqueeze(0).unsqueeze(0), size=(
                out_features, in_features-self.fixed_in), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
            if self.fixed_in != 0:
                fixed_weights = F.interpolate(fixed_weights.unsqueeze(0).unsqueeze(0), size=(
                    out_features, self.fixed_in), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
            else:
                # A zero-width slice cannot be interpolated, but its row count must still
                # match the resized dynamic part for the concatenation below.
                fixed_weights = dynamic_weights.new_empty((out_features, 0))
        if out_features != self.weights.size(0):
            this_bias = F.interpolate(this_bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), size=(
                1, out_features), mode='bilinear', align_corners=False).squeeze(0).squeeze(0).squeeze(0)
        return F.linear(x, torch.cat((fixed_weights, dynamic_weights), dim=1), this_bias)
120 |
121 |
class DynamicLinearMlp(nn.Module):
    """
    MLP whose first half of hidden channels is mixed along the token axis by
    two ``DynamicLinear`` layers (one producing the prompt tokens, one the
    sequence tokens), while the second half is passed through untouched.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        prefix_token_length=None,
        group=1,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        # Base width of the token-mixing weights; they are interpolated to
        # the actual token count at call time.
        base_tokens = hidden_features // 4

        self.fc1 = nn.Conv1d(in_features, hidden_features, 3,
                             groups=group, bias=bias_pair[0], padding=1)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])

        if norm_layer is not None:
            self.norm = norm_layer(hidden_features)
        else:
            self.norm = nn.Identity()
        self.seq_fc = DynamicLinear(
            base_tokens, base_tokens,
            bias=bias_pair[1], fixed_in=prefix_token_length)
        self.prompt_fc = DynamicLinear(
            base_tokens, prefix_token_length,
            bias=bias_pair[1], fixed_in=prefix_token_length)

        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])
        self.hidden_features = hidden_features
        self.prefix_token_length = prefix_token_length

    def dynamic_linear(self, x, prefix_seq_len):
        """Mix the first ``prefix_seq_len`` positions of the last axis.

        The leading slice is mapped once to ``prefix_token_length`` prompt
        positions and once to the remaining sequence positions; everything
        after ``prefix_seq_len`` is appended unchanged.
        """
        head = x[:, :, :prefix_seq_len]
        tail = x[:, :, prefix_seq_len:]
        n_prompt = self.prefix_token_length
        seq_part = self.seq_fc(head, head.shape[-1] - n_prompt)
        prompt_part = self.prompt_fc(head, n_prompt)
        return torch.cat((prompt_part, seq_part, tail), dim=-1)

    def split_dynamic_linear(self, x, prefix_seq_len):
        """Apply ``dynamic_linear`` to the first half of dim -2 only."""
        mixed_half, kept_half = x.chunk(2, dim=-2)
        mixed_half = self.dynamic_linear(mixed_half, prefix_seq_len)
        return torch.cat((mixed_half, kept_half), dim=-2)

    def forward(self, x, prefix_seq_len, dim=2):
        """Run the MLP on a 4-D input and return a tensor of the same shape.

        ``dim`` is accepted for interface compatibility and is not used.
        """
        n, var, l, c = x.shape
        # Conv1d expects channels before length: (n*var, c, l).
        tokens = x.view(-1, l, c).transpose(-1, -2)
        hidden = self.fc1(tokens)
        hidden = self.split_dynamic_linear(hidden, prefix_seq_len)
        hidden = self.drop1(self.act(hidden))
        hidden = self.norm(hidden.transpose(1, 2))
        out = self.fc2(hidden).view(n, var, l, c)
        return self.drop2(out)
185 |
186 |
class LearnablePositionalEmbedding(nn.Module):
    """Trainable positional table initialised with sinusoidal encodings."""

    def __init__(self, d_model, max_len=5000):
        super(LearnablePositionalEmbedding, self).__init__()
        # Sinusoidal table: even channels get sin, odd channels get cos.
        positions = torch.arange(0, max_len).float().unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = positions * inv_freq

        table = torch.zeros(max_len, d_model).float()
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Stored as (1, 1, max_len, d_model) and trained with the model.
        self.pe = nn.Parameter(torch.zeros(
            1, 1, max_len, d_model), requires_grad=True)
        self.pe.data.copy_(table.unsqueeze(0).unsqueeze(0).float())

    def forward(self, x, offset=0):
        """Return the table rows ``offset .. offset + x.size(2)``."""
        end = offset + x.size(2)
        return self.pe[:, :, offset:end]
208 |
209 |
class SeqAttention(nn.Module):
    """Multi-head self-attention over the token axis of a (B, N, C) input."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(head_dim)
            self.k_norm = norm_layer(head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask=None):
        """Attend over tokens; ``attn_mask`` is accepted but not applied."""
        B, N, C = x.shape
        packed = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # Each of q/k/v becomes (B, heads, N, head_dim).
        query, key, value = (t.transpose(1, 2) for t in packed.unbind(2))
        query = self.q_norm(query)
        key = self.k_norm(key)

        drop_p = self.attn_drop.p if self.training else 0.
        attended = F.scaled_dot_product_attention(
            query, key, value, dropout_p=drop_p)

        merged = attended.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))
250 |
251 |
class VarAttention(nn.Module):
    """
    Attention across the variable axis of a (B, N, P, C) input.

    Queries and keys are averaged over the P axis before attention, while
    values keep all P positions (flattened into the feature axis) so the
    per-position content is mixed between the N entries.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(head_dim)
            self.k_norm = norm_layer(head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, P, C = x.shape
        heads = self.num_heads

        packed = self.qkv(x).reshape(B, N, P, 3, heads, self.head_dim)
        # After the permute each of q/k/v is (B, P, heads, N, head_dim).
        query, key, value = packed.permute(3, 0, 2, 4, 1, 5).unbind(0)
        query = self.q_norm(query)
        key = self.k_norm(key)

        # Collapse the P axis for q/k; fold it into the features for v.
        query = query.mean(dim=1, keepdim=False)
        key = key.mean(dim=1, keepdim=False)
        value = value.permute(0, 2, 3, 4, 1).reshape(B, heads, N, -1)

        drop_p = self.attn_drop.p if self.training else 0.
        attended = F.scaled_dot_product_attention(
            query, key, value, dropout_p=drop_p)

        # Undo the value flattening and restore the (B, N, P, C) layout.
        attended = attended.view(B, heads, N, -1, P).permute(0, 2, 4, 1, 3)
        merged = attended.reshape(B, N, P, -1)
        return self.proj_drop(self.proj(merged))
298 |
299 |
class GateLayer(nn.Module):
    """Scale the input by a learned sigmoid gate computed from the input."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        # ``init_values`` is accepted for interface compatibility; it is not
        # used when building the gate.
        self.inplace = inplace
        self.gate = nn.Linear(dim, 1)

    def forward(self, x):
        # One scalar gate per position, broadcast over the feature axis.
        gate = torch.sigmoid(self.gate(x))
        return gate * x
309 |
310 |
class SeqAttBlock(nn.Module):
    """Pre-norm residual block applying ``SeqAttention`` along the token axis."""

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=False,
        qk_norm=False,
        proj_drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn_seq = SeqAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )

        self.ls1 = GateLayer(dim, init_values=init_values)
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()
        # Not used in forward(); kept so the parameter set is unchanged.
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, attn_mask):
        residual = x
        normed = self.norm1(x)
        n_vars, n_seqs = normed.shape[1], normed.shape[2]
        # Merge the leading axes so attention sees a 3-D (batch, tokens, dim)
        # tensor, then restore the variable axis afterwards.
        flat = normed.reshape(-1, normed.shape[-2], normed.shape[-1])
        attended = self.attn_seq(flat, attn_mask)
        attended = attended.reshape(-1, n_vars, n_seqs, attended.shape[-1])
        return residual + self.drop_path1(self.ls1(attended))
353 |
354 |
class VarAttBlock(nn.Module):
    """Pre-norm residual block applying ``VarAttention`` across variables."""

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=False,
        qk_norm=False,
        proj_drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn_var = VarAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = GateLayer(dim, init_values=init_values)
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()
        # Not used in forward(); kept so the parameter set is unchanged.
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        attended = self.attn_var(self.norm1(x))
        gated = self.ls1(attended)
        return x + self.drop_path1(gated)
388 |
389 |
class MLPBlock(nn.Module):
    """Pre-norm residual block wrapping a configurable MLP layer."""

    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        proj_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        mlp_layer=None,
        prefix_token_length=0,
    ):
        super().__init__()
        self.norm2 = norm_layer(dim)

        mlp_kwargs = dict(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        # Only the dynamic MLP understands the prompt length argument.
        if mlp_layer is DynamicLinearMlp:
            mlp_kwargs['prefix_token_length'] = prefix_token_length
        self.mlp = mlp_layer(**mlp_kwargs)

        self.ls2 = GateLayer(dim, init_values=init_values)
        if drop_path > 0.:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

    def forward(self, x, prefix_seq_len=None):
        """Apply norm -> MLP -> gate -> drop-path and add the residual.

        ``prefix_seq_len`` is forwarded to the MLP only when it is given.
        """
        normed = self.norm2(x)
        if prefix_seq_len is None:
            update = self.mlp(normed)
        else:
            update = self.mlp(normed, prefix_seq_len=prefix_seq_len)
        return x + self.drop_path2(self.ls2(update))
433 |
434 |
class BasicBlock(nn.Module):
    """One backbone layer: sequence attention, variable attention, dynamic MLP."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=8.,
        qkv_bias=False,
        qk_norm=False,
        proj_drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        prefix_token_length=0,
    ):
        super().__init__()
        # Both attention blocks are configured identically.
        attn_kwargs = dict(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            init_values=init_values,
            proj_drop=proj_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
        )
        self.seq_att_block = SeqAttBlock(**attn_kwargs)

        self.var_att_block = VarAttBlock(**attn_kwargs)

        self.dynamic_mlp = MLPBlock(
            dim=dim,
            mlp_ratio=mlp_ratio,
            mlp_layer=DynamicLinearMlp,
            proj_drop=proj_drop,
            init_values=init_values,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
            prefix_token_length=prefix_token_length,
        )

    def forward(self, x, prefix_seq_len, attn_mask):
        out = self.seq_att_block(x, attn_mask)
        out = self.var_att_block(out)
        return self.dynamic_mlp(out, prefix_seq_len=prefix_seq_len)
472 |
473 |
class PatchEmbedding(nn.Module):
    """Cut the last axis into non-overlapping patches and embed each patch."""

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super(PatchEmbedding, self).__init__()
        # Patching: only non-overlapping patches are supported.
        # ``padding`` is accepted for interface compatibility and is unused.
        assert patch_len == stride, "non-overlap"
        self.patch_len = patch_len
        self.stride = stride
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(embedded_patches, n_vars)``.

        ``n_vars`` is ``x.shape[1]``; the first two axes of the unfolded
        tensor are merged before the linear embedding is applied.
        """
        n_vars = x.shape[1]
        patches = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        dims = patches.shape
        patches = patches.reshape(dims[0] * dims[1], dims[2], dims[3])
        embedded = self.value_embedding(patches)
        return self.dropout(embedded), n_vars
490 |
491 |
class CLSHead(nn.Module):
    """Classification head that scores a pooled CLS token against category tokens."""

    def __init__(self, d_model, head_dropout=0):
        super().__init__()
        d_mid = d_model
        self.proj_in = nn.Linear(d_model, d_mid)
        self.cross_att = CrossAttention(d_mid)

        self.mlp = MLPBlock(
            dim=d_mid,
            mlp_ratio=8,
            mlp_layer=Mlp,
            proj_drop=head_dropout,
            init_values=None,
            drop_path=0.0,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            prefix_token_length=None,
        )

    def forward(self, x, category_token=None, return_feature=False):
        """Return class scores, or the CLS feature when ``return_feature``.

        The last token along axis 2 is used as the query of the cross
        attention over all tokens. Scores are averaged over axis 1 of the
        per-variable result.
        """
        feats = self.proj_in(x)
        B, V, L, C = feats.shape
        tokens = feats.view(-1, L, C)
        query = tokens[:, -1:]
        cls_token = self.cross_att(tokens, query=query).reshape(B, V, -1, C)

        cls_token = self.mlp(cls_token)
        if return_feature:
            return cls_token

        num_categories = category_token.shape[2]
        expanded = cls_token.expand(B, V, num_categories, C)
        scores = torch.einsum('nvkc,nvmc->nvm', expanded, category_token)
        return scores.mean(dim=1)
521 |
522 |
class ForecastHead(nn.Module):
    """Decode backbone tokens into a series by projecting tokens to patches and folding them."""

    def __init__(self, d_model, patch_len, stride, pad, head_dropout=0, prefix_token_length=None):
        super().__init__()
        d_mid = d_model
        self.proj_in = nn.Linear(d_model, d_mid)
        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=int(d_model * 4),
            act_layer=nn.GELU,
            drop=head_dropout,
        )
        self.proj_out = nn.Linear(d_model, patch_len)
        self.pad = pad
        self.patch_len = patch_len
        self.stride = stride
        # Token-axis mixer; its 128x128 base weight is interpolated to the
        # actual token counts at call time.
        self.pos_proj = DynamicLinear(
            in_features=128, out_features=128, fixed_in=prefix_token_length)

    def forward(self, x_full, pred_len, token_len):
        """Map all tokens to ``token_len`` output tokens and fold to length ``pred_len``.

        Returns a tensor whose last two axes are swapped relative to the
        folded per-variable series (series length before variables).
        """
        hidden = self.proj_in(x_full)
        tail_tokens = hidden[:, :, -token_len:]
        # Mix along the token axis, then add the last ``token_len`` tokens
        # back as a residual.
        mixed = self.pos_proj(hidden.transpose(-1, -2), token_len)
        mixed = mixed.transpose(-1, -2) + tail_tokens
        patches = self.proj_out(self.mlp(mixed))

        bs, n_vars = patches.shape[0], patches.shape[1]
        patches = patches.reshape(-1, patches.shape[-2], patches.shape[-1])
        patches = patches.permute(0, 2, 1)
        series = torch.nn.functional.fold(
            patches,
            output_size=(pred_len, 1),
            kernel_size=(self.patch_len, 1),
            stride=(self.stride, 1),
        )
        series = series.squeeze(dim=-1).reshape(bs, n_vars, -1)
        return series.permute(0, 2, 1)
560 |
561 |
class Model(nn.Module):
    """
    UniTS: Building a Unified Time Series Model

    A single backbone shared by forecasting, classification, imputation,
    anomaly detection and pretraining. Inputs are normalised, cut into
    patches, combined with learned prompt / mask / cls tokens and passed
    through a stack of ``BasicBlock`` layers before a task-specific head.
    """

    def __init__(self, args, configs_list, pretrain=False):
        """
        Args:
            args: namespace providing ``prompt_num``, ``d_model``,
                ``patch_len``, ``stride``, ``dropout``, ``e_layers`` and
                ``n_heads``; when ``pretrain`` is true also ``right_prob``,
                ``min_mask_ratio`` and ``max_mask_ratio``.
            configs_list: sequence of entries where ``entry[0]`` is the task
                data name and ``entry[1]`` is a dict with ``'task_name'`` plus
                ``'enc_in'``/``'num_class'`` (classification) or
                ``'seq_len'``/``'pred_len'`` (long_term_forecast).
            pretrain: when true, masking settings are stored and an extra
                ``pretrain_head`` is created.
        """
        super().__init__()

        if pretrain:
            self.right_prob = args.right_prob
            self.min_mask_ratio = args.min_mask_ratio
            self.max_mask_ratio = args.max_mask_ratio

        # Tokens settings
        self.num_task = len(configs_list)
        self.prompt_token = nn.Parameter(
            torch.zeros(1, 1, args.prompt_num, args.d_model))
        torch.nn.init.normal_(self.prompt_token, std=.02)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, args.d_model))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, args.d_model))
        torch.nn.init.normal_(self.cls_token, std=.02)
        self.category_tokens = nn.ParameterDict({})

        # One set of category tokens per classification dataset.
        for i in range(self.num_task):
            task_data_name = configs_list[i][0]
            if configs_list[i][1]['task_name'] == 'classification':
                self.category_tokens[task_data_name] = torch.zeros(
                    1, configs_list[i][1]['enc_in'], configs_list[i][1]['num_class'], args.d_model)
                torch.nn.init.normal_(
                    self.category_tokens[task_data_name], std=.02)

        # Per-dataset head settings: the class count for classification, or
        # [pred_token_len, pred_len, seq_len + pred_len] for forecasting.
        # NOTE(review): only 'long_term_forecast' registers an entry here,
        # while forward() also routes 'short_term_forecast' to forecast(),
        # which reads this dict — confirm that is intended.
        self.cls_nums = {}
        for i in range(self.num_task):
            task_data_name = configs_list[i][0]
            if configs_list[i][1]['task_name'] == 'classification':
                self.cls_nums[task_data_name] = configs_list[i][1]['num_class']
            elif configs_list[i][1]['task_name'] == 'long_term_forecast':
                remainder = configs_list[i][1]['seq_len'] % args.patch_len
                if remainder == 0:
                    padding = 0
                else:
                    padding = args.patch_len - remainder
                input_token_len = calculate_unfold_output_length(
                    configs_list[i][1]['seq_len']+padding, args.stride, args.patch_len)
                input_pad = args.stride * \
                    (input_token_len - 1) + args.patch_len - \
                    configs_list[i][1]['seq_len']
                pred_token_len = calculate_unfold_output_length(
                    configs_list[i][1]['pred_len']-input_pad, args.stride, args.patch_len)
                real_len = configs_list[i][1]['seq_len'] + \
                    configs_list[i][1]['pred_len']
                self.cls_nums[task_data_name] = [pred_token_len,
                                                 configs_list[i][1]['pred_len'], real_len]

        self.configs_list = configs_list

        ### model settings ###
        self.prompt_num = args.prompt_num
        self.stride = args.stride
        self.pad = args.stride
        self.patch_len = args.patch_len

        # input processing
        self.patch_embeddings = PatchEmbedding(
            args.d_model, args.patch_len, args.stride, args.stride, args.dropout)
        self.position_embedding = LearnablePositionalEmbedding(args.d_model)
        self.prompt2forecat = DynamicLinear(128, 128, fixed_in=args.prompt_num)

        # basic blocks
        self.block_num = args.e_layers
        self.blocks = nn.ModuleList(
            [BasicBlock(dim=args.d_model, num_heads=args.n_heads, qkv_bias=False, qk_norm=False,
                        mlp_ratio=8., proj_drop=args.dropout, attn_drop=0., drop_path=0.,
                        init_values=None, prefix_token_length=args.prompt_num) for l in range(args.e_layers)]
        )

        # output processing
        self.cls_head = CLSHead(args.d_model, head_dropout=args.dropout)
        self.forecast_head = ForecastHead(
            args.d_model, args.patch_len, args.stride, args.stride, prefix_token_length=args.prompt_num, head_dropout=args.dropout)
        if pretrain:
            self.pretrain_head = ForecastHead(
                args.d_model, args.patch_len, args.stride, args.stride, prefix_token_length=1, head_dropout=args.dropout)

    def tokenize(self, x, mask=None):
        """Normalise ``x`` along dim 1, pad to a patch multiple and embed patches.

        Returns ``(tokens, means, stdev, n_vars, padding)`` where ``means``
        and ``stdev`` are the statistics needed for de-normalisation and
        ``padding`` is the number of zero steps appended on the right.
        """
        # Normalization from Non-stationary Transformer
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        if mask is not None:
            # Masked positions (mask == 0) are zeroed and excluded from the
            # variance denominator.
            x = x.masked_fill(mask == 0, 0)
            stdev = torch.sqrt(torch.sum(x * x, dim=1) /
                               torch.sum(mask == 1, dim=1) + 1e-5)
            stdev = stdev.unsqueeze(dim=1)
        else:
            stdev = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x /= stdev
        x = x.permute(0, 2, 1)
        remainder = x.shape[2] % self.patch_len
        if remainder != 0:
            padding = self.patch_len - remainder
            x = F.pad(x, (0, padding))
        else:
            padding = 0
        x, n_vars = self.patch_embeddings(x)
        return x, means, stdev, n_vars, padding

    def prepare_prompt(self, x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name=None, mask=None):
        """Assemble the backbone input for ``task_name``.

        The patch tokens are reshaped to ``(-1, n_vars, tokens, dim)``,
        position embeddings are added and the prefix prompt is prepended.
        Forecasting appends ``task_prompt_num`` initialised mask tokens,
        classification appends the task prompt, and imputation replaces
        masked tokens before the prompt is prepended.
        """
        x = torch.reshape(
            x, (-1, n_vars, x.shape[-2], x.shape[-1]))
        # append prompt tokens
        this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)

        if task_name == 'forecast':
            this_mask_prompt = task_prompt.repeat(
                x.shape[0], 1, task_prompt_num, 1)
            init_full_input = torch.cat(
                (this_prompt, x, this_mask_prompt), dim=-2)
            # Initialise the future tokens from the whole input by mixing
            # along the token axis.
            init_mask_prompt = self.prompt2forecat(init_full_input.transpose(
                -1, -2), init_full_input.shape[2]-prefix_prompt.shape[2]).transpose(-1, -2)
            this_function_prompt = init_mask_prompt[:, :, -task_prompt_num:]
            x = torch.cat((this_prompt, x, this_function_prompt), dim=2)
            # Position embeddings are added to everything after the prompt.
            x[:, :, self.prompt_num:] = x[:, :, self.prompt_num:] + \
                self.position_embedding(x[:, :, self.prompt_num:])
        elif task_name == 'classification':
            this_function_prompt = task_prompt.repeat(x.shape[0], 1, 1, 1)
            x = x + self.position_embedding(x)
            x = torch.cat((this_prompt, x, this_function_prompt), dim=2)
        elif task_name == 'imputation':
            # fill the masked parts with mask tokens
            # for imputation, masked is 0, unmasked is 1, so here to reverse mask
            mask = 1-mask
            mask = mask.permute(0, 2, 1)
            mask = self.mark2token(mask)
            mask_repeat = mask.unsqueeze(dim=-1)

            mask_token = task_prompt
            mask_repeat = mask_repeat.repeat(1, 1, 1, x.shape[-1])
            x = x * (1-mask_repeat) + mask_token * mask_repeat

            init_full_input = torch.cat((this_prompt, x), dim=-2)
            init_mask_prompt = self.prompt2forecat(
                init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2)
            # keep the unmasked tokens and fill the masked ones with init_mask_prompt.
            x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat
            x = x + self.position_embedding(x)
            x = torch.cat((this_prompt, x), dim=2)
        elif task_name == 'anomaly_detection':
            x = x + self.position_embedding(x)
            x = torch.cat((this_prompt, x), dim=2)

        return x

    def mark2token(self, x_mark):
        """Reduce a per-step mark to one flag per patch (1.0 if any step in the patch is > 0 on average)."""
        x_mark = x_mark.unfold(
            dimension=-1, size=self.patch_len, step=self.stride)
        x_mark = x_mark.mean(dim=-1)
        x_mark = (x_mark > 0).float()
        return x_mark

    def backbone(self, x, prefix_len, seq_len):
        """Run all blocks; ``prefix_len + seq_len`` is passed as ``prefix_seq_len``."""
        attn_mask = None
        for block in self.blocks:
            x = block(x, prefix_seq_len=prefix_len +
                      seq_len, attn_mask=attn_mask)
        return x

    def forecast(self, x, x_mark, task_id):
        """Forecast the last ``pred_len`` steps for the task at ``task_id``."""
        task_data_name = self.configs_list[task_id][0]
        task_prompt_num = self.cls_nums[task_data_name][0]
        task_seq_num = self.cls_nums[task_data_name][1]
        real_seq_len = self.cls_nums[task_data_name][2]
        x, means, stdev, n_vars, _ = self.tokenize(x)
        prefix_prompt = self.prompt_token.repeat(1, n_vars, 1, 1)
        task_prompt = self.mask_token.repeat(1, n_vars, 1, 1)

        x = self.prepare_prompt(
            x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name='forecast')

        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]
        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)

        x = self.forecast_head(
            x, real_seq_len, seq_token_len)
        # Keep only the predicted horizon at the end of the decoded series.
        x = x[:, -task_seq_num:]

        # De-Normalization from Non-stationary Transformer
        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))

        return x

    def classification(self, x, x_mark, task_id):
        """Return class scores for the task at ``task_id``."""
        task_data_name = self.configs_list[task_id][0]
        task_prompt_num = 1
        category_token = self.category_tokens[task_data_name]

        x, means, stdev, n_vars, _ = self.tokenize(x)
        prefix_prompt = self.prompt_token.repeat(1, n_vars, 1, 1)
        task_prompt = self.cls_token.repeat(1, n_vars, 1, 1)

        seq_len = x.shape[-2]

        x = self.prepare_prompt(
            x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name='classification')

        x = self.backbone(x, prefix_prompt.shape[2], seq_len)

        x = self.cls_head(x, category_token)

        return x

    def imputation(self, x, x_mark, mask, task_id):
        """Reconstruct the series; ``mask`` marks observed steps with 1 and missing with 0."""
        seq_len = x.shape[1]
        x, means, stdev, n_vars, padding = self.tokenize(x, mask)
        prefix_prompt = self.prompt_token.repeat(1, n_vars, 1, 1)
        task_prompt = self.mask_token.repeat(1, n_vars, 1, 1)

        x = self.prepare_prompt(
            x, n_vars, prefix_prompt, task_prompt, None, mask=mask, task_name='imputation')
        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]
        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)

        x = self.forecast_head(
            x, seq_len+padding, seq_token_len)
        # Drop the right padding added in tokenize().
        x = x[:, :seq_len]

        # De-Normalization from Non-stationary Transformer
        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))

        return x

    def anomaly_detection(self, x, x_mark, task_id):
        """Reconstruct the input series (the reconstruction is what is returned)."""
        seq_len = x.shape[1]
        x, means, stdev, n_vars, padding = self.tokenize(x)
        prefix_prompt = self.prompt_token.repeat(1, n_vars, 1, 1)

        x = self.prepare_prompt(x, n_vars, prefix_prompt,
                                None, None, task_name='anomaly_detection')
        seq_token_len = x.shape[-2]-prefix_prompt.shape[2]
        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)

        x = self.forecast_head(
            x, seq_len+padding, seq_token_len)
        # Drop the right padding added in tokenize().
        x = x[:, :seq_len]

        # De-Normalization from Non-stationary Transformer
        x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))
        x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1))

        return x

    def random_masking(self, x, min_mask_ratio, max_mask_ratio):
        """
        Perform per-sample random masking.

        Returns a float mask of shape ``(N, L)`` where 0 is keep and 1 is
        remove; each sample draws its own ratio from
        ``[min_mask_ratio, max_mask_ratio)``.
        """
        N, V, L, D = x.shape  # batch, var, length, dim

        # Calculate mask ratios and lengths to keep for each sample in the batch
        mask_ratios = torch.rand(N, device=x.device) * \
            (max_mask_ratio - min_mask_ratio) + min_mask_ratio
        len_keeps = (L * (1 - mask_ratios)).long()

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        # ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # generate the binary mask: 0 is keep, 1 is remove
        # Create a range tensor and compare with len_keeps for mask generation
        range_tensor = torch.arange(L, device=x.device).expand(N, L)
        mask = (range_tensor >= len_keeps.unsqueeze(1))

        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        mask = mask.float()

        return mask

    def right_masking(self, x, min_mask_ratio, max_mask_ratio):
        """Mask a contiguous suffix of each sample; returns a float ``(N, L)`` mask (1 is remove)."""
        N, V, L, D = x.shape  # batch, var, length, dim

        # Randomly choose a mask ratio for each sample within the specified range
        mask_ratios = torch.rand(N, device=x.device) * \
            (max_mask_ratio - min_mask_ratio) + min_mask_ratio
        len_keeps = (L * (1 - mask_ratios)).long()

        # Binary mask creation without a for loop
        len_keeps_matrix = len_keeps.unsqueeze(1).expand(N, L)
        indices = torch.arange(L, device=x.device).expand_as(len_keeps_matrix)
        mask = indices >= len_keeps_matrix
        mask = mask.float()

        return mask

    def choose_masking(self, x, right_prob, min_mask_ratio, max_mask_ratio):
        """Pick right-masking with probability ``right_prob``, else random masking."""
        # Generate a random number to decide which masking function to use
        if torch.rand(1).item() > right_prob:
            return self.random_masking(x, min_mask_ratio, max_mask_ratio)
        else:
            return self.right_masking(x, min_mask_ratio, max_mask_ratio)

    def get_mask_seq(self, mask, seq_len):
        """Expand a per-token mask ``(N, L)`` to a per-step 0/1 mask of length ``seq_len``."""
        mask_seq = mask.unsqueeze(dim=-1).repeat(1, 1, self.patch_len)
        mask_seq = mask_seq.permute(0, 2, 1)
        # Large negative fill so that kept tokens stay non-positive after fold.
        mask_seq = mask_seq.masked_fill(mask_seq == 0, -1e9)
        # Fold operation
        mask_seq = torch.nn.functional.fold(mask_seq, output_size=(
            seq_len, 1), kernel_size=(self.patch_len, 1), stride=(self.stride, 1))
        # Apply threshold to bring back to 0/1 values
        mask_seq = (mask_seq > 0).float()
        mask_seq = mask_seq.squeeze(dim=-1).squeeze(dim=1)
        return mask_seq

    def pretraining(self, x, x_mark, task_id, enable_mask=False):
        """Pretraining forward pass.

        Returns ``(cls_dec_out, mask_dec_out, mask_seq)`` when
        ``enable_mask`` is truthy, otherwise only ``cls_dec_out`` (the
        reconstruction decoded from the classification-token feature).
        Previously the non-masked path raised ``UnboundLocalError`` because
        ``cls_dec_out`` was only computed inside the masked branch.
        """
        seq_len = x.shape[1]
        x, means, stdev, n_vars, padding = self.tokenize(x)
        seq_token_len = x.shape[-2]
        prefix_prompt = self.prompt_token.repeat(1, n_vars, 1, 1)
        mask_token = self.mask_token.repeat(1, n_vars, 1, 1)
        cls_token = self.cls_token.repeat(1, n_vars, 1, 1)

        # append prompt tokens
        x = torch.reshape(
            x, (-1, n_vars, x.shape[-2], x.shape[-1]))
        # prepare prompts
        this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1)

        if enable_mask:
            mask = self.choose_masking(x, self.right_prob,
                                       self.min_mask_ratio, self.max_mask_ratio)
            mask_repeat = mask.unsqueeze(dim=1).unsqueeze(dim=-1)
            mask_repeat = mask_repeat.repeat(1, x.shape[1], 1, x.shape[-1])
            x = x * (1-mask_repeat) + mask_token * mask_repeat  # todo

            init_full_input = torch.cat((this_prompt, x), dim=-2)
            init_mask_prompt = self.prompt2forecat(
                init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2)
            # keep the unmasked tokens and fill the masked ones with init_mask_prompt.
            x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat
            mask_seq = self.get_mask_seq(mask, seq_len+padding)
            mask_seq = mask_seq[:, :seq_len]
        # Position embeddings are added on both paths so the backbone always
        # sees positioned tokens.
        x = x + self.position_embedding(x)
        this_function_prompt = cls_token.repeat(x.shape[0], 1, 1, 1)
        x = torch.cat((this_prompt, x, this_function_prompt), dim=2)

        x = self.backbone(x, prefix_prompt.shape[2], seq_token_len)

        if enable_mask:
            # The trailing cls token is excluded from masked reconstruction.
            mask_dec_out = self.forecast_head(
                x[:, :, :-1], seq_len+padding, seq_token_len)
            mask_dec_out = mask_dec_out[:, :seq_len]
            # De-Normalization from Non-stationary Transformer
            mask_dec_out = mask_dec_out * \
                (stdev[:, 0, :].unsqueeze(1).repeat(
                    1, mask_dec_out.shape[1], 1))
            mask_dec_out = mask_dec_out + \
                (means[:, 0, :].unsqueeze(1).repeat(
                    1, mask_dec_out.shape[1], 1))

        cls_dec_out = self.cls_head(x, return_feature=True)
        # detach grad of the forecasting on tokens
        fused_dec_out = torch.cat(
            (cls_dec_out, x[:, :, self.prompt_num:-1].detach()), dim=2)
        cls_dec_out = self.pretrain_head(
            fused_dec_out, seq_len+padding, seq_token_len)
        cls_dec_out = cls_dec_out[:, :seq_len]
        cls_dec_out = cls_dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, cls_dec_out.shape[1], 1))
        cls_dec_out = cls_dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(
                1, cls_dec_out.shape[1], 1))

        if enable_mask:
            return cls_dec_out, mask_dec_out, mask_seq
        return cls_dec_out

    def forward(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None,
                mask=None, task_id=None, task_name=None, enable_mask=None):
        """Dispatch to the task-specific method selected by ``task_name``.

        Returns ``None`` when ``task_name`` is ``None`` or matches no known
        task.
        """
        if task_name == 'long_term_forecast' or task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, task_id)
            return dec_out  # [B, L, D]
        if task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, mask, task_id)
            return dec_out  # [B, L, D]
        if task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc, x_mark_enc, task_id)
            return dec_out  # [B, L, D]
        if task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc, task_id)
            return dec_out  # [B, N]
        # Guard against task_name=None, which would make the `in` test raise.
        if task_name is not None and 'pretrain' in task_name:
            dec_out = self.pretraining(x_enc, x_mark_enc, task_id,
                                       enable_mask=enable_mask)
            return dec_out
        return None
964 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mims-harvard/UniTS/0e0281482864017cac8832b2651906ff5375a34e/models/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.4.0
2 | matplotlib==3.7.0
3 | numpy==1.23.5
4 | pandas==1.5.3
5 | patool==1.12
6 | reformer-pytorch==1.4.4
7 | scikit-learn==1.2.2
8 | scipy==1.10.1
9 | sktime==0.16.1
10 | sympy==1.11.1
11 | tqdm==4.64.1
12 | pyyaml
13 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from exp.exp_sup import Exp_All_Task as Exp_All_Task_SUP
4 | import random
5 | import numpy as np
6 | import wandb
7 | from utils.ddp import is_main_process, init_distributed_mode
8 |
9 |
10 | if __name__ == '__main__':
11 | parser = argparse.ArgumentParser(description='UniTS supervised training')
12 |
13 | # basic config
14 | parser.add_argument('--task_name', type=str, required=False, default='ALL_task',
15 | help='task name')
16 | parser.add_argument('--is_training', type=int,
17 | required=True, default=1, help='status')
18 | parser.add_argument('--model_id', type=str, required=True,
19 | default='test', help='model id')
20 | parser.add_argument('--model', type=str, required=True, default='UniTS',
21 | help='model name')
22 |
23 | # data loader
24 | parser.add_argument('--data', type=str, required=False,
25 | default='All', help='dataset type')
26 | parser.add_argument('--features', type=str, default='M',
27 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
28 | parser.add_argument('--target', type=str, default='OT',
29 | help='target feature in S or MS task')
30 | parser.add_argument('--freq', type=str, default='h',
31 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
32 | parser.add_argument('--task_data_config_path', type=str,
33 | default='exp/all_task.yaml', help='root path of the task and data yaml file')
34 | parser.add_argument('--subsample_pct', type=float,
35 | default=None, help='subsample percent')
36 |
37 | # ddp
38 | parser.add_argument('--local-rank', type=int, help='local rank')
39 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
40 | distributed training; see https://pytorch.org/docs/stable/distributed.html""")
41 | parser.add_argument('--num_workers', type=int, default=0,
42 | help='data loader num workers')
43 | parser.add_argument("--memory_check", action="store_true", default=True)
44 | parser.add_argument("--large_model", action="store_true", default=True)
45 |
46 | # optimization
47 | parser.add_argument('--itr', type=int, default=1, help='experiments times')
48 | parser.add_argument('--train_epochs', type=int,
49 | default=10, help='train epochs')
50 | parser.add_argument("--prompt_tune_epoch", type=int, default=0)
51 | parser.add_argument('--warmup_epochs', type=int,
52 | default=0, help='warmup epochs')
53 | parser.add_argument('--batch_size', type=int, default=32,
54 | help='batch size of train input data')
55 | parser.add_argument('--acc_it', type=int, default=1,
56 | help='acc iteration to enlarge batch size')
57 | parser.add_argument('--learning_rate', type=float,
58 | default=0.0001, help='optimizer learning rate')
59 | parser.add_argument('--min_lr', type=float, default=None,
60 | help='optimizer min learning rate')
61 | parser.add_argument('--weight_decay', type=float,
62 | default=0.0, help='optimizer weight decay')
63 | parser.add_argument('--layer_decay', type=float,
64 | default=None, help='optimizer layer decay')
65 | parser.add_argument('--des', type=str, default='test',
66 | help='exp description')
67 | parser.add_argument('--lradj', type=str,
68 | default='supervised', help='adjust learning rate')
69 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
70 | help='Clip gradient norm (default: None, no clipping)')
71 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
72 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/',
73 | help='save location of model checkpoints')
74 | parser.add_argument('--pretrained_weight', type=str, default=None,
75 | help='location of pretrained model checkpoints')
76 | parser.add_argument('--debug', type=str,
77 | default='enabled', help='disabled')
78 | parser.add_argument('--project_name', type=str,
79 | default='tsfm-multitask', help='wandb project name')
80 |
81 | # model settings
82 | parser.add_argument('--d_model', type=int, default=512,
83 | help='dimension of model')
84 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
85 | parser.add_argument('--e_layers', type=int, default=2,
86 | help='num of encoder layers')
87 | parser.add_argument("--share_embedding",
88 | action="store_true", default=False)
89 | parser.add_argument("--patch_len", type=int, default=16)
90 | parser.add_argument("--stride", type=int, default=8)
91 | parser.add_argument("--prompt_num", type=int, default=5)
92 | parser.add_argument('--fix_seed', type=int, default=None, help='seed')
93 |
94 | # task related settings
95 | # forecasting task
96 | parser.add_argument('--inverse', action='store_true',
97 | help='inverse output data', default=False)
98 |
99 | # imputation task
100 | parser.add_argument('--mask_rate', type=float,
101 | default=0.25, help='mask ratio')
102 |
103 | # anomaly detection task
104 | parser.add_argument('--anomaly_ratio', type=float,
105 | default=1.0, help='prior anomaly ratio (%)')
106 |
107 | # zero-shot-forecast-new-length
108 | parser.add_argument("--offset", type=int, default=0)
109 | parser.add_argument("--max_offset", type=int, default=0)
110 | parser.add_argument('--zero_shot_forecasting_new_length',
111 | type=str, default=None, help='unify')
112 |
113 | args = parser.parse_args()
114 | init_distributed_mode(args)
115 | if args.fix_seed is not None:
116 | random.seed(args.fix_seed)
117 | torch.manual_seed(args.fix_seed)
118 | np.random.seed(args.fix_seed)
119 |
120 | print('Args in experiment:')
121 | print(args)
122 | exp_name = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}'.format(
123 | args.task_name,
124 | args.model_id,
125 | args.model,
126 | args.data,
127 | args.features,
128 | args.d_model,
129 | args.e_layers,
130 | args.des)
131 |
132 | if int(args.prompt_tune_epoch) != 0:
133 | exp_name = 'Ptune'+str(args.prompt_tune_epoch)+'_'+exp_name
134 | print(exp_name)
135 |
136 | if is_main_process():
137 | wandb.init(
138 | name=exp_name,
139 | # set the wandb project where this run will be logged
140 | project=args.project_name,
141 | # track hyperparameters and run metadata
142 | config=args,
143 | mode=args.debug,
144 | )
145 |
146 | Exp = Exp_All_Task_SUP
147 |
148 | if args.is_training:
149 | for ii in range(args.itr):
150 | # setting record of experiments
151 | setting = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}_{}'.format(
152 | args.task_name,
153 | args.model_id,
154 | args.model,
155 | args.data,
156 | args.features,
157 | args.d_model,
158 | args.e_layers,
159 | args.des, ii)
160 |
161 | exp = Exp(args) # set experiments
162 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
163 | exp.train(setting)
164 | else:
165 | ii = 0
166 | setting = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}_{}'.format(
167 | args.task_name,
168 | args.model_id,
169 | args.model,
170 | args.data,
171 | args.features,
172 | args.d_model,
173 | args.e_layers,
174 | args.des, ii)
175 |
176 | exp = Exp(args) # set experiments
177 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
178 | exp.test(setting, load_pretrain=True)
179 | torch.cuda.empty_cache()
180 |
--------------------------------------------------------------------------------
/run_pretrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from exp.exp_pretrain import Exp_All_Task as Exp_All_Task_SSL
4 | import random
5 | import numpy as np
6 | import wandb
7 | from utils.ddp import is_main_process, init_distributed_mode
8 |
9 | if __name__ == '__main__':
10 | parser = argparse.ArgumentParser(description='UniTS Pretrain')
11 | parser.add_argument('--fix_seed', type=int, default=None, help='seed')
12 | # basic config
13 | parser.add_argument('--task_name', type=str, required=False, default='ALL_task',
14 | help='task name')
15 | parser.add_argument('--is_training', type=int,
16 | required=True, default=1, help='status')
17 | parser.add_argument('--model_id', type=str, required=True,
18 | default='test', help='model id')
19 | parser.add_argument('--model', type=str, required=True, default='UniTS',
20 | help='model name')
21 |
22 | # data loader
23 | parser.add_argument('--data', type=str, required=False,
24 | default='All', help='dataset type')
25 | parser.add_argument('--features', type=str, default='M',
26 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
27 | parser.add_argument('--target', type=str, default='OT',
28 | help='target feature in S or MS task')
29 | parser.add_argument('--freq', type=str, default='h',
30 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
31 | parser.add_argument('--task_data_config_path', type=str,
32 | default='exp/all_task_pretrain.yaml', help='root path of the task and data yaml file')
33 | parser.add_argument('--subsample_pct', type=float,
34 | default=None, help='subsample percent')
35 |
36 | # pretrain
37 | parser.add_argument('--right_prob', type=float,
38 | default=1.0, help='right mask prob')
39 | parser.add_argument('--min_mask_ratio', type=float,
40 | default=0.5, help='min right mask prob')
41 | parser.add_argument('--max_mask_ratio', type=float,
42 | default=0.8, help='max right mask prob')
43 | parser.add_argument('--min_keep_ratio', type=float, default=None,
44 | help='min crop ratio for various length in pretraining')
45 |
46 | # ddp
47 | parser.add_argument('--local-rank', type=int, help='local rank')
48 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
49 | distributed training; see https://pytorch.org/docs/stable/distributed.html""")
50 | parser.add_argument('--num_workers', type=int, default=0,
51 | help='data loader num workers')
52 |
53 | # optimization
54 | parser.add_argument('--itr', type=int, default=1, help='experiments times')
55 | parser.add_argument('--train_epochs', type=int,
56 | default=10, help='train epochs')
57 | parser.add_argument('--warmup_epochs', type=int,
58 | default=0, help='warmup epochs')
59 | parser.add_argument('--batch_size', type=int, default=32,
60 | help='batch size of train input data')
61 | parser.add_argument('--acc_it', type=int, default=32,
62 | help='acc iteration to enlarge batch size')
63 | parser.add_argument('--learning_rate', type=float,
64 | default=0.0001, help='optimizer learning rate')
65 | parser.add_argument('--min_lr', type=float, default=1e-6,
66 | help='optimizer learning rate')
67 | parser.add_argument('--beta2', type=float,
68 | default=0.999, help='optimizer beta2')
69 | parser.add_argument('--weight_decay', type=float,
70 | default=0.0, help='optimizer weight decay')
71 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
72 | parser.add_argument('--eps', type=float, default=1e-08,
73 | help='eps for optimizer')
74 | parser.add_argument('--des', type=str, default='test',
75 | help='exp description')
76 | parser.add_argument('--debug', type=str,
77 | default='enabled', help='disabled')
78 | parser.add_argument('--clip_grad', type=float, default=None, help="""Maximal parameter
79 | gradient norm if using gradient clipping.""")
80 | parser.add_argument('--checkpoints', type=str,
81 | default='./checkpoints/', help='location of model checkpoints')
82 |
83 | parser.add_argument("--memory_check", action="store_true", default=True)
84 | parser.add_argument("--large_model", action="store_true", default=True)
85 |
86 | # model settings
87 | parser.add_argument('--d_model', type=int, default=512,
88 | help='dimension of model')
89 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
90 | parser.add_argument('--e_layers', type=int, default=2,
91 | help='num of encoder layers')
92 | parser.add_argument("--patch_len", type=int, default=16)
93 | parser.add_argument("--stride", type=int, default=8)
94 | parser.add_argument("--prompt_num", type=int, default=10)
95 |
96 | args = parser.parse_args()
97 | init_distributed_mode(args)
98 |
99 | print('Args in experiment:')
100 | print(args)
101 | if args.fix_seed is not None:
102 | random.seed(args.fix_seed)
103 | torch.manual_seed(args.fix_seed)
104 | np.random.seed(args.fix_seed)
105 | exp_name = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}'.format(
106 | args.task_name,
107 | args.model_id,
108 | args.model,
109 | args.data,
110 | args.features,
111 | args.d_model,
112 | args.e_layers,
113 | args.des)
114 |
115 | if is_main_process():
116 | wandb.init(
117 | name=exp_name,
118 | # set the wandb project where this run will be logged
119 | project="pretrain",
120 | # track hyperparameters and run metadata
121 | config=args,
122 | mode=args.debug,
123 | )
124 | Exp = Exp_All_Task_SSL
125 |
126 | if args.is_training:
127 | for ii in range(args.itr):
128 | # setting record of experiments
129 | setting = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}_{}'.format(
130 | args.task_name,
131 | args.model_id,
132 | args.model,
133 | args.data,
134 | args.features,
135 | args.d_model,
136 | args.e_layers,
137 | args.des, ii)
138 |
139 | exp = Exp(args) # set experiments
140 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
141 | exp.train(setting)
142 |
143 | torch.cuda.empty_cache()
144 |
--------------------------------------------------------------------------------
/scripts/few_shot_anomaly_detection/UniTS_finetune_few_shot_anomaly_detection.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=anomaly_detection
4 | exp_name=finetune_few_shot_anomaly_detection_pct05
5 |
6 | # Path to the SSL pre-trained checkpoint
7 | # get ssl pretrained checkpoint: scripts/pretrain_prompt_learning/UniTS_pretrain_x32.sh
8 | ckpt_path=newcheckpoints/units_x32_pretrain_checkpoint.pth
9 | random_port=$((RANDOM % 9000 + 1000))
10 |
11 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \
12 | --fix_seed 2021 \
13 | --is_training 1 \
14 | --subsample_pct 0.05 \
15 | --model_id $exp_name \
16 | --pretrained_weight $ckpt_path \
17 | --model $model_name \
18 | --prompt_num 10 \
19 | --patch_len 16 \
20 | --stride 16 \
21 | --e_layers 3 \
22 | --d_model 32 \
23 | --des 'Exp' \
24 | --itr 1 \
25 | --lradj finetune_anl \
26 | --learning_rate 5e-4 \
27 | --weight_decay 1e-3 \
28 | --train_epochs 10 \
29 | --batch_size 32 \
30 | --acc_it 32 \
31 | --dropout 0 \
32 | --debug $wandb_mode \
33 | --project_name $project_name \
34 | --clip_grad 100 \
35 | --task_data_config_path data_provider/anomaly_detection.yaml
--------------------------------------------------------------------------------
/scripts/few_shot_anomaly_detection/UniTS_prompt_tuning_few_shot_anomaly_detection.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=anomaly_detection
4 | exp_name=prompt_tuning_few_shot_anomaly_detection_pct05
5 |
6 |
7 | # Path to the SSL pre-trained checkpoint
8 | # get ssl pretrained checkpoint: scripts/pretrain_prompt_learning/UniTS_pretrain_x32.sh
9 | ckpt_path=newcheckpoints/units_x32_pretrain_checkpoint.pth
10 | random_port=$((RANDOM % 9000 + 1000))
11 |
12 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \
13 | --fix_seed 2021 \
14 | --is_training 1 \
15 | --subsample_pct 0.05 \
16 | --model_id $exp_name \
17 | --pretrained_weight $ckpt_path \
18 | --model $model_name \
19 | --prompt_num 10 \
20 | --patch_len 16 \
21 | --stride 16 \
22 | --e_layers 3 \
23 | --d_model 32 \
24 | --des 'Exp' \
25 | --itr 1 \
26 | --lradj prompt_tuning \
27 | --learning_rate 5e-5 \
28 | --weight_decay 1e-2 \
29 | --train_epochs 0 \
30 | --prompt_tune_epoch 10 \
31 | --batch_size 32 \
32 | --acc_it 32 \
33 | --dropout 0.0 \
34 | --debug $wandb_mode \
35 | --project_name $project_name \
36 | --clip_grad 100 \
37 | --task_data_config_path data_provider/anomaly_detection.yaml
--------------------------------------------------------------------------------
/scripts/few_shot_imputation/UniTS_finetune_few_shot_imputation_mask025.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_imputation
4 | exp_name=fewshot_imputation_finetune_mask025
5 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth
6 |
7 | random_port=$((RANDOM % 9000 + 1000))
8 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \
9 | --is_training 1 \
10 | --fix_seed 2021 \
11 | --model_id $exp_name \
12 | --subsample_pct 0.1 \
13 | --mask_rate 0.25 \
14 | --model $model_name \
15 | --prompt_num 10 \
16 | --patch_len 16 \
17 | --stride 16 \
18 | --e_layers 3 \
19 | --d_model 64 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --lradj finetune_imp \
23 | --learning_rate 3e-4 \
24 | --weight_decay 5e-6 \
25 | --train_epochs 20 \
26 | --batch_size 32 \
27 | --acc_it 32 \
28 | --clip_grad 1.0 \
29 | --debug $wandb_mode \
30 | --project_name $project_name \
31 | --pretrained_weight $ckpt_path \
32 | --task_data_config_path data_provider/imputation.yaml \
--------------------------------------------------------------------------------
/scripts/few_shot_imputation/UniTS_finetune_few_shot_imputation_mask050.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_imputation
4 | exp_name=fewshot_imputation_finetune_mask050
5 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth
6 |
7 | random_port=$((RANDOM % 9000 + 1000))
8 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \
9 | --is_training 1 \
10 | --fix_seed 2021 \
11 | --model_id $exp_name \
12 | --subsample_pct 0.1 \
13 | --mask_rate 0.50 \
14 | --model $model_name \
15 | --prompt_num 10 \
16 | --patch_len 16 \
17 | --stride 16 \
18 | --e_layers 3 \
19 | --d_model 64 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --lradj finetune_imp \
23 | --learning_rate 3e-4 \
24 | --weight_decay 5e-6 \
25 | --train_epochs 20 \
26 | --batch_size 32 \
27 | --acc_it 32 \
28 | --clip_grad 1.0 \
29 | --debug $wandb_mode \
30 | --project_name $project_name \
31 | --pretrained_weight $ckpt_path \
32 | --task_data_config_path data_provider/imputation.yaml \
--------------------------------------------------------------------------------
/scripts/few_shot_imputation/UniTS_prompt_tuning_few_shot_imputation_mask025.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_imputation
4 | exp_name=fewshot_imputation_prompt_tuning_mask025
5 |
6 | # Path to the SSL pre-trained checkpoint
7 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth
8 |
9 | random_port=$((RANDOM % 9000 + 1000))
10 | ckpt_path=newcheckpoints/units_prompt_tuning_few_shot_imputation_mask025.pth
11 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \
12 | --is_training 1 \
13 | --fix_seed 2021 \
14 | --model_id $exp_name \
15 | --subsample_pct 0.1 \
16 | --mask_rate 0.25 \
17 | --model $model_name \
18 | --prompt_num 10 \
19 | --patch_len 16 \
20 | --stride 16 \
21 | --e_layers 3 \
22 | --d_model 128 \
23 | --des 'Exp' \
24 | --itr 1 \
25 | --prompt_tune_epoch 20 \
26 | --train_epochs 0 \
27 | --lradj prompt_tuning \
28 | --learning_rate 5e-3 \
29 | --weight_decay 0 \
30 | --batch_size 32 \
31 | --acc_it 32 \
32 | --clip_grad 1.0 \
33 | --dropout 0 \
34 | --debug $wandb_mode \
35 | --project_name $project_name \
36 | --pretrained_weight $ckpt_path \
37 | --task_data_config_path data_provider/imputation.yaml \
--------------------------------------------------------------------------------
/scripts/few_shot_imputation/UniTS_prompt_tuning_few_shot_imputation_mask050.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_imputation
4 | exp_name=fewshot_imputation_prompt_tuning_mask050
5 |
6 | # Path to the SSL pre-trained checkpoint
7 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth
8 |
9 | random_port=$((RANDOM % 9000 + 1000))
10 |
11 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \
12 | --is_training 1 \
13 | --fix_seed 2021 \
14 | --model_id $exp_name \
15 | --subsample_pct 0.1 \
16 | --mask_rate 0.50 \
17 | --model $model_name \
18 | --prompt_num 10 \
19 | --patch_len 16 \
20 | --stride 16 \
21 | --e_layers 3 \
22 | --d_model 128 \
23 | --des 'Exp' \
24 | --itr 1 \
25 | --prompt_tune_epoch 20 \
26 | --train_epochs 0 \
27 | --lradj prompt_tuning \
28 | --learning_rate 5e-3 \
29 | --weight_decay 0 \
30 | --batch_size 32 \
31 | --acc_it 32 \
32 | --clip_grad 1.0 \
33 | --dropout 0 \
34 | --debug $wandb_mode \
35 | --project_name $project_name \
36 | --pretrained_weight $ckpt_path \
37 | --task_data_config_path data_provider/imputation.yaml \
--------------------------------------------------------------------------------
/scripts/few_shot_newdata/UniTS_finetune_few_shot_newdata_pct05.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_newdata
4 | exp_name=fewshot_newdata_finetune_pct05
5 | random_port=$((RANDOM % 9000 + 1000))
6 | # Path to the supervised checkpoint
7 | # get supervised checkpoint: scripts/supervised_learning/UniTS_supervised.sh
8 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth
9 | torchrun --nnodes 1 --master_port $random_port run.py \
10 | --is_training 1 \
11 | --fix_seed 2021 \
12 | --model_id $exp_name \
13 | --subsample_pct 0.05 \
14 | --model $model_name \
15 | --prompt_num 10 \
16 | --patch_len 16 \
17 | --stride 16 \
18 | --e_layers 3 \
19 | --d_model 64 \
20 | --des 'Exp' \
21 | --train_epochs 5 \
22 | --learning_rate 1e-4 \
23 | --weight_decay 1e-5 \
24 | --lradj supervised \
25 | --dropout 0.1 \
26 | --acc_it 8 \
27 | --clip_grad 100 \
28 | --debug $wandb_mode \
29 | --project_name $project_name \
30 | --pretrained_weight $ckpt_path \
31 | --task_data_config_path data_provider/fewshot_new_task.yaml
32 |
33 |
--------------------------------------------------------------------------------
/scripts/few_shot_newdata/UniTS_finetune_few_shot_newdata_pct15.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_newdata
4 | exp_name=fewshot_newdata_finetune_pct15
5 | random_port=$((RANDOM % 9000 + 1000))
6 | # Path to the supervised checkpoint
7 | # get supervised checkpoint: scripts/supervised_learning/UniTS_supervised.sh
8 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth
9 | torchrun --nnodes 1 --master_port $random_port run.py \
10 | --is_training 1 \
11 | --fix_seed 2021 \
12 | --model_id $exp_name \
13 | --subsample_pct 0.15 \
14 | --model $model_name \
15 | --prompt_num 10 \
16 | --patch_len 16 \
17 | --stride 16 \
18 | --e_layers 3 \
19 | --d_model 64 \
20 | --des 'Exp' \
21 | --train_epochs 5 \
22 | --learning_rate 1e-4 \
23 | --weight_decay 1e-5 \
24 | --lradj supervised \
25 | --dropout 0.1 \
26 | --acc_it 8 \
27 | --clip_grad 100 \
28 | --debug $wandb_mode \
29 | --project_name $project_name \
30 | --pretrained_weight $ckpt_path \
31 | --task_data_config_path data_provider/fewshot_new_task.yaml
32 |
33 |
--------------------------------------------------------------------------------
/scripts/few_shot_newdata/UniTS_finetune_few_shot_newdata_pct20.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_newdata
4 | exp_name=fewshot_newdata_finetune_pct20
5 | random_port=$((RANDOM % 9000 + 1000))
6 | # Path to the supervised checkpoint
7 | # get supervised checkpoint: scripts/supervised_learning/UniTS_supervised.sh
8 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth
9 | torchrun --nnodes 1 --master_port $random_port run.py \
10 | --is_training 1 \
11 | --fix_seed 2021 \
12 | --model_id $exp_name \
13 | --subsample_pct 0.20 \
14 | --model $model_name \
15 | --prompt_num 10 \
16 | --patch_len 16 \
17 | --stride 16 \
18 | --e_layers 3 \
19 | --d_model 64 \
20 | --des 'Exp' \
21 | --train_epochs 5 \
22 | --learning_rate 1e-4 \
23 | --weight_decay 1e-5 \
24 | --lradj supervised \
25 | --dropout 0.1 \
26 | --acc_it 8 \
27 | --clip_grad 100 \
28 | --debug $wandb_mode \
29 | --project_name $project_name \
30 | --pretrained_weight $ckpt_path \
31 | --task_data_config_path data_provider/fewshot_new_task.yaml
32 |
33 |
--------------------------------------------------------------------------------
/scripts/few_shot_newdata/UniTS_prompt_tuning_few_shot_newdata_pct05.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_newdata
4 | exp_name=fewshot_newdata_prompt_tuning_pct05
5 | random_port=$((RANDOM % 9000 + 1000))
6 |
7 | # Path to the SSL pre-trained checkpoint
8 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth
9 |
10 | torchrun --nnodes 1 --master_port $random_port run.py \
11 | --is_training 1 \
12 | --fix_seed 2021 \
13 | --model_id $exp_name \
14 | --subsample_pct 0.05 \
15 | --model $model_name \
16 | --prompt_num 10 \
17 | --patch_len 16 \
18 | --stride 16 \
19 | --e_layers 3 \
20 | --d_model 128 \
21 | --des 'Exp' \
22 | --prompt_tune_epoch 10 \
23 | --train_epochs 0 \
24 | --lradj prompt_tuning \
25 | --learning_rate 1e-3 \
26 | --weight_decay 1e-4 \
27 | --dropout 0 \
28 | --acc_it 8 \
29 | --clip_grad 100 \
30 | --debug $wandb_mode \
31 | --project_name $project_name \
32 | --pretrained_weight $ckpt_path \
33 | --task_data_config_path data_provider/fewshot_new_task.yaml
--------------------------------------------------------------------------------
/scripts/few_shot_newdata/UniTS_prompt_tuning_few_shot_newdata_pct15.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_newdata
4 | exp_name=fewshot_newdata_prompt_tuning_pct15
5 | random_port=$((RANDOM % 9000 + 1000))
6 |
7 | # Path to the SSL pre-trained checkpoint
8 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth
9 |
10 | torchrun --nnodes 1 --master_port $random_port run.py \
11 | --is_training 1 \
12 | --fix_seed 2021 \
13 | --model_id $exp_name \
14 | --subsample_pct 0.15 \
15 | --model $model_name \
16 | --prompt_num 10 \
17 | --patch_len 16 \
18 | --stride 16 \
19 | --e_layers 3 \
20 | --d_model 128 \
21 | --des 'Exp' \
22 | --prompt_tune_epoch 10 \
23 | --train_epochs 0 \
24 | --lradj prompt_tuning \
25 | --learning_rate 1e-3 \
26 | --weight_decay 1e-4 \
27 | --dropout 0 \
28 | --acc_it 8 \
29 | --clip_grad 100 \
30 | --debug $wandb_mode \
31 | --project_name $project_name \
32 | --pretrained_weight $ckpt_path \
33 | --task_data_config_path data_provider/fewshot_new_task.yaml
--------------------------------------------------------------------------------
/scripts/few_shot_newdata/UniTS_prompt_tuning_few_shot_newdata_pct20.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | wandb_mode=online
3 | project_name=fewshot_newdata
4 | exp_name=fewshot_newdata_prompt_tuning_pct20
5 | random_port=$((RANDOM % 9000 + 1000))
6 |
7 | # Path to the SSL pre-trained checkpoint
8 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth
9 |
10 | torchrun --nnodes 1 --master_port $random_port run.py \
11 | --is_training 1 \
12 | --fix_seed 2021 \
13 | --model_id $exp_name \
14 | --subsample_pct 0.20 \
15 | --model $model_name \
16 | --prompt_num 10 \
17 | --patch_len 16 \
18 | --stride 16 \
19 | --e_layers 3 \
20 | --d_model 128 \
21 | --des 'Exp' \
22 | --prompt_tune_epoch 10 \
23 | --train_epochs 0 \
24 | --lradj prompt_tuning \
25 | --learning_rate 1e-3 \
26 | --weight_decay 1e-4 \
27 | --dropout 0 \
28 | --acc_it 8 \
29 | --clip_grad 100 \
30 | --debug $wandb_mode \
31 | --project_name $project_name \
32 | --pretrained_weight $ckpt_path \
33 | --task_data_config_path data_provider/fewshot_new_task.yaml
--------------------------------------------------------------------------------
/scripts/pretrain_prompt_learning/UniTS_pretrain_x128.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | exp_name=UniTS_pretrain_x128
3 | wandb_mode=online
4 | ptune_name=prompt_tuning
5 |
6 | d_model=128
7 |
8 | random_port=$((RANDOM % 9000 + 1000))
9 |
10 | # Pretrain
11 | torchrun --nnodes 1 --nproc-per-node 2 --master_port $random_port run_pretrain.py \
12 | --is_training 1 \
13 | --model_id $exp_name \
14 | --model $model_name \
15 | --prompt_num 10 \
16 | --patch_len 16 \
17 | --stride 16 \
18 | --e_layers 3 \
19 | --d_model $d_model \
20 | --des 'Exp' \
21 | --acc_it 128 \
22 | --batch_size 32 \
23 | --learning_rate 5e-5 \
24 | --min_lr 1e-4 \
25 | --weight_decay 5e-6 \
26 | --train_epochs 10 \
27 | --warmup_epochs 0 \
28 | --min_keep_ratio 0.5 \
29 | --right_prob 0.5 \
30 | --min_mask_ratio 0.7 \
31 | --max_mask_ratio 0.8 \
32 | --debug $wandb_mode \
33 | --task_data_config_path data_provider/multi_task_pretrain.yaml
34 |
35 | # Prompt tuning
36 | torchrun --nnodes 1 --master_port $random_port run.py \
37 | --is_training 1 \
38 | --model_id $exp_name \
39 | --model $model_name \
40 | --lradj prompt_tuning \
41 | --prompt_num 10 \
42 | --patch_len 16 \
43 | --stride 16 \
44 | --e_layers 3 \
45 | --d_model $d_model \
46 | --des 'Exp' \
47 | --itr 1 \
48 | --learning_rate 3e-3 \
49 | --weight_decay 0 \
50 | --prompt_tune_epoch 2 \
51 | --train_epochs 0 \
52 | --acc_it 32 \
53 | --debug $wandb_mode \
54 | --project_name $ptune_name \
55 | --clip_grad 100 \
56 | --pretrained_weight auto \
57 | --task_data_config_path data_provider/multi_task.yaml
58 |
--------------------------------------------------------------------------------
/scripts/pretrain_prompt_learning/UniTS_pretrain_x32.sh:
--------------------------------------------------------------------------------
1 | model_name=UniTS
2 | exp_name=UniTS_pretrain_x32
3 | wandb_mode=online
4 | ptune_name=prompt_tuning
5 |
6 | d_model=32
7 |
8 | random_port=$((RANDOM % 9000 + 1000))
9 |
10 | # Pretrain
11 | torchrun --nnodes 1 --nproc-per-node 1 --master_port $random_port run_pretrain.py \
12 | --is_training 1 \
13 | --model_id $exp_name \
14 | --model $model_name \
15 | --prompt_num 10 \
16 | --patch_len 16 \
17 | --stride 16 \
18 | --e_layers 3 \
19 | --d_model $d_model \
20 | --des 'Exp' \
21 | --acc_it 128 \
22 | --batch_size 32 \
23 | --learning_rate 5e-5 \
24 | --min_lr 1e-4 \
25 | --weight_decay 5e-6 \
26 | --train_epochs 10 \
27 | --warmup_epochs 0 \
28 | --min_keep_ratio 0.5 \
29 | --right_prob 0.5 \
30 | --min_mask_ratio 0.7 \
31 | --max_mask_ratio 0.8 \
32 | --debug $wandb_mode \
33 | --task_data_config_path data_provider/multi_task_pretrain.yaml
34 |
35 | # Prompt tuning
36 | torchrun --nnodes 1 --master_port $random_port run.py \
37 | --is_training 1 \
38 | --model_id $exp_name \
39 | --model $model_name \
40 | --lradj prompt_tuning \
41 | --prompt_num 10 \
42 | --patch_len 16 \
43 | --stride 16 \
44 | --e_layers 3 \
45 | --d_model $d_model \
46 | --des 'Exp' \
47 | --itr 1 \
48 | --learning_rate 1e-3 \
49 | --weight_decay 0 \
50 | --prompt_tune_epoch 2 \
51 | --train_epochs 0 \
52 | --acc_it 32 \
53 | --debug $wandb_mode \
54 | --project_name $ptune_name \
55 | --clip_grad 100 \
56 | --pretrained_weight auto \
57 | --task_data_config_path data_provider/multi_task.yaml
58 |
--------------------------------------------------------------------------------
/scripts/pretrain_prompt_learning/UniTS_pretrain_x64.sh:
--------------------------------------------------------------------------------
# Two-stage UniTS experiment with d_model=64:
#   1. pretraining through run_pretrain.py (2 processes on one node)
#   2. prompt tuning through run.py (started with --pretrained_weight auto)
model_name=UniTS
exp_name=UniTS_pretrain_x64
wandb_mode=online
ptune_name=prompt_tuning

d_model=64

# Random rendezvous port in [1024, 9999]. The range starts at 1024 because
# lower ports are privileged and cannot be bound by an unprivileged torchrun.
random_port=$((RANDOM % 8976 + 1024))

# Pretrain
torchrun --nnodes 1 --nproc-per-node 2 --master_port "$random_port" run_pretrain.py \
  --is_training 1 \
  --model_id "$exp_name" \
  --model "$model_name" \
  --prompt_num 10 \
  --patch_len 16 \
  --stride 16 \
  --e_layers 3 \
  --d_model "$d_model" \
  --des 'Exp' \
  --acc_it 128 \
  --batch_size 32 \
  --learning_rate 5e-5 \
  --min_lr 1e-4 \
  --weight_decay 5e-6 \
  --train_epochs 10 \
  --warmup_epochs 0 \
  --min_keep_ratio 0.5 \
  --right_prob 0.5 \
  --min_mask_ratio 0.7 \
  --max_mask_ratio 0.8 \
  --debug "$wandb_mode" \
  --task_data_config_path data_provider/multi_task_pretrain.yaml

# Prompt tuning
torchrun --nnodes 1 --master_port "$random_port" run.py \
  --is_training 1 \
  --model_id "$exp_name" \
  --model "$model_name" \
  --lradj prompt_tuning \
  --prompt_num 10 \
  --patch_len 16 \
  --stride 16 \
  --e_layers 3 \
  --d_model "$d_model" \
  --des 'Exp' \
  --itr 1 \
  --learning_rate 3e-3 \
  --weight_decay 0 \
  --prompt_tune_epoch 2 \
  --train_epochs 0 \
  --acc_it 32 \
  --debug "$wandb_mode" \
  --project_name "$ptune_name" \
  --clip_grad 100 \
  --pretrained_weight auto \
  --task_data_config_path data_provider/multi_task.yaml
--------------------------------------------------------------------------------
/scripts/supervised_learning/UniTS_supervised.sh:
--------------------------------------------------------------------------------
# Supervised multi-task UniTS run (d_model=64) launched through run.py.
model_name=UniTS
exp_name=UniTS_supervised_x64
wandb_mode=online
project_name=supervised_learning

# Random rendezvous port in [1024, 9999]. The range starts at 1024 because
# lower ports are privileged and cannot be bound by an unprivileged torchrun.
random_port=$((RANDOM % 8976 + 1024))

# Supervised learning
torchrun --nnodes 1 --nproc-per-node=1 --master_port "$random_port" run.py \
  --is_training 1 \
  --model_id "$exp_name" \
  --model "$model_name" \
  --lradj supervised \
  --prompt_num 10 \
  --patch_len 16 \
  --stride 16 \
  --e_layers 3 \
  --d_model 64 \
  --des 'Exp' \
  --learning_rate 1e-4 \
  --weight_decay 5e-6 \
  --train_epochs 5 \
  --batch_size 32 \
  --acc_it 32 \
  --debug "$wandb_mode" \
  --project_name "$project_name" \
  --clip_grad 100 \
  --task_data_config_path data_provider/multi_task.yaml
--------------------------------------------------------------------------------
/scripts/zero_shot/UniTS_forecast_new_length_unify.sh:
--------------------------------------------------------------------------------
1 |
2 |
# Evaluation-only run (--is_training 0) of zero-shot forecasting with a new
# forecast length, using a pretrained d_model=128 checkpoint.
model_name=UniTS
exp_name=UniTS_pretrain_x128
wandb_mode=disabled
ptune_name=prompt_tuning

d_model=128

# Random rendezvous port in [1024, 9999]. The range starts at 1024 because
# lower ports are privileged and cannot be bound by an unprivileged torchrun.
random_port=$((RANDOM % 8976 + 1024))

# Get the pretrained model first by running
# scripts/pretrain_prompt_learning/UniTS_pretrain_x128.sh,
# then point ckpt_path at the resulting checkpoint file.
ckpt_path=pretrain_ckpt.pth


offset=384
torchrun --nnodes 1 --master_port "$random_port" run.py \
  --is_training 0 \
  --zero_shot_forecasting_new_length unify \
  --max_offset 384 \
  --offset "$offset" \
  --model_id "$exp_name" \
  --model "$model_name" \
  --lradj prompt_tuning \
  --prompt_num 10 \
  --patch_len 16 \
  --stride 16 \
  --e_layers 3 \
  --d_model "$d_model" \
  --des 'Exp' \
  --itr 1 \
  --debug "$wandb_mode" \
  --project_name "$ptune_name" \
  --pretrained_weight "$ckpt_path" \
  --task_data_config_path data_provider/multitask_zero_shot_new_length.yaml
--------------------------------------------------------------------------------
/scripts/zero_shot/UniTS_zeroshot_newdata.sh:
--------------------------------------------------------------------------------
# Zero-shot UniTS variant (d_model=64):
#   1. pretraining through run_pretrain.py (2 processes on one node)
#   2. evaluation-only run (--is_training 0) on new forecasting datasets
model_name=UniTS_zeroshot
exp_name=UniTS_zeroshot_pretrain_x64
wandb_mode=online
ptune_name=zeroshot_newdata

d_model=64

# Random rendezvous port in [1024, 9999]. The range starts at 1024 because
# lower ports are privileged and cannot be bound by an unprivileged torchrun.
random_port=$((RANDOM % 8976 + 1024))

# Pretrain the zero-shot version of UniTS
torchrun --nnodes 1 --nproc-per-node 2 --master_port "$random_port" run_pretrain.py \
  --is_training 1 \
  --model_id "$exp_name" \
  --model "$model_name" \
  --prompt_num 10 \
  --patch_len 16 \
  --stride 16 \
  --e_layers 3 \
  --d_model "$d_model" \
  --des 'Exp' \
  --acc_it 128 \
  --batch_size 32 \
  --learning_rate 5e-5 \
  --min_lr 1e-4 \
  --weight_decay 5e-6 \
  --train_epochs 10 \
  --warmup_epochs 0 \
  --min_keep_ratio 0.5 \
  --right_prob 0.5 \
  --min_mask_ratio 0.7 \
  --max_mask_ratio 0.8 \
  --debug "$wandb_mode" \
  --task_data_config_path data_provider/multi_task_pretrain.yaml

# Zero-shot test on new forecasting datasets
# Note: the inference in this code tests all samples of each dataset, which
# differs from the original paper, where only 1 sample per dataset is tested.
torchrun --nnodes 1 --master_port "$random_port" run.py \
  --is_training 0 \
  --model_id "$exp_name" \
  --model "$model_name" \
  --prompt_num 10 \
  --patch_len 16 \
  --stride 16 \
  --e_layers 3 \
  --d_model "$d_model" \
  --des 'Exp' \
  --debug "$wandb_mode" \
  --project_name "$ptune_name" \
  --pretrained_weight auto \
  --task_data_config_path data_provider/zeroshot_task.yaml
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mims-harvard/UniTS/0e0281482864017cac8832b2651906ff5375a34e/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class BalancedDataLoaderIterator:
    """Draw batches from several dataloaders, choosing one at random per step.

    One pass yields ``max(len(loader)) * len(dataloaders)`` items. Each item
    is a ``(batch, dataloader_index)`` pair. A dataloader that is exhausted
    before the pass ends is restarted from its beginning.
    """

    def __init__(self, dataloaders):
        self.dataloaders = dataloaders
        self.num_dataloaders = len(dataloaders)

        lengths = [len(loader) for loader in dataloaders]
        longest = max(lengths)
        self.total_length = longest * self.num_dataloaders

        print("data loader length:", lengths)
        print("max dataloader length:", longest,
              "epoch iteration:", self.total_length)

        self.current_iteration = 0
        # Equal sampling weight for every dataloader.
        self.probabilities = torch.ones(
            self.num_dataloaders, dtype=torch.float) / self.num_dataloaders

    def __iter__(self):
        # Start a fresh pass: new sub-iterators and a reset step counter.
        self.current_iteration = 0
        self.iterators = [iter(loader) for loader in self.dataloaders]
        return self

    def __next__(self):
        if self.current_iteration >= self.total_length:
            raise StopIteration

        picked = torch.multinomial(self.probabilities, 1).item()
        try:
            batch = next(self.iterators[picked])
        except StopIteration:
            # The chosen dataloader ran dry; restart it and draw again.
            self.iterators[picked] = iter(self.dataloaders[picked])
            batch = next(self.iterators[picked])

        self.current_iteration += 1
        return batch, picked

    def __len__(self):
        return self.total_length

    def generate_fake_samples_for_batch(self, dataloader_id, batch_size):
        """Build all-zero stand-ins for the tensors of one real batch.

        The first batch of the selected dataloader is used as a template:
        every tensor element is replaced by zeros whose leading dimension is
        ``batch_size``; non-tensor elements are dropped.

        Returns ``(fake_samples, dataloader_id)``, or ``None`` when the
        dataloader yields no batch. Raises ``ValueError`` for an
        out-of-range ``dataloader_id``.
        """
        if not 0 <= dataloader_id < len(self.dataloaders):
            raise ValueError("Invalid dataloader ID")

        template_iter = iter(self.dataloaders[dataloader_id])

        try:
            template_batch = next(template_iter)
            zeros_like_batch = [
                torch.zeros([batch_size] + list(item.shape)[1:])
                for item in template_batch
                if isinstance(item, torch.Tensor)
            ]
            return zeros_like_batch, dataloader_id
        except StopIteration:
            return None
65 |
--------------------------------------------------------------------------------
/utils/ddp.py:
--------------------------------------------------------------------------------
1 | import torch.distributed as dist
2 | import torch
3 |
4 |
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()
11 |
12 |
def get_world_size():
    """Return the process-group world size, or 1 without a usable group."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
17 |
18 |
def get_rank():
    """Return this process's rank, or 0 without a usable process group."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
23 |
24 |
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
27 |
28 |
def init_distributed_mode(args):
    """Initialise the NCCL process group and bind this process to a GPU.

    After the group is created the process selects its CUDA device, clears
    the CUDA cache, waits on a barrier, and silences ``print`` on every rank
    except rank 0 (see ``setup_for_distributed``).

    Args:
        args: Accepted for interface compatibility; not read here.
    """
    import os

    dist.init_process_group(
        backend="nccl",
    )
    rank = dist.get_rank()
    # CUDA device indices are per node, so use the node-local rank exported
    # by torchrun (LOCAL_RANK). The global rank is only a valid device index
    # on a single node; it remains the fallback when LOCAL_RANK is absent.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    torch.cuda.empty_cache()
    print(f"Start running basic DDP on rank {rank}.")

    dist.barrier()
    setup_for_distributed(rank == 0)
41 |
42 |
def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so that only the master process prints.

    The replacement accepts an extra ``force`` keyword; ``force=True`` prints
    even on non-master processes. All other arguments are forwarded to the
    original ``print`` unchanged.
    """
    import builtins as __builtin__

    original_print = __builtin__.print

    def print(*args, **kwargs):
        # ``force`` is consumed here; the real print never sees it.
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    __builtin__.print = print
56 |
def gather_tensors_from_all_gpus(tensor_list, device_id, to_numpy=True):
    """
    All-gather every tensor of ``tensor_list`` across the process group.

    Args:
        tensor_list (list of torch.Tensor): Tensors held by this process.
        device_id: Device each tensor is moved to before the collective.
        to_numpy (bool): If True, return NumPy arrays (copied to CPU)
            instead of tensors.

    Returns:
        A flat list. For each input tensor, in order, it holds one gathered
        copy per process, i.e. ``len(tensor_list) * world_size`` entries.
    """
    world_size = dist.get_world_size()
    local_tensors = [item.to(device_id).contiguous() for item in tensor_list]

    collected = []
    for item in local_tensors:
        # One receive buffer per process, shaped like the local tensor.
        receive_buffers = [torch.empty_like(item) for _ in range(world_size)]
        dist.all_gather(receive_buffers, item)
        collected.extend(receive_buffers)

    if to_numpy:
        return [item.cpu().numpy() for item in collected]
    return collected
90 |
--------------------------------------------------------------------------------
/utils/layer_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import json
13 |
14 |
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Build optimizer parameter groups with layer-wise learning-rate decay.
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Trainable parameters are grouped by (layer id, decay / no-decay). Each
    group carries an ``lr_scale`` of ``layer_decay ** (num_layers - layer_id)``
    and its own ``weight_decay``. Parameters that are 1-D or named in
    ``no_weight_decay_list`` get a weight decay of 0.

    Returns:
        list of dicts with keys ``lr_scale``, ``weight_decay`` and ``params``.
    """
    num_layers = len(model.blocks) + 2
    layer_scales = [layer_decay ** (num_layers - depth)
                    for depth in range(num_layers + 1)]

    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        exempt = param.ndim == 1 or name in no_weight_decay_list
        decay_tag = "no_decay" if exempt else "decay"
        decay_value = 0. if exempt else weight_decay

        layer_id = get_layer_id_for_model(name, num_layers)
        group_name = "layer_%d_%s" % (layer_id, decay_tag)

        if group_name not in groups:
            groups[group_name] = {
                "lr_scale": layer_scales[layer_id],
                "weight_decay": decay_value,
                "params": [],
            }
        group = groups[group_name]
        print("name: %s, layer_id: %d, group_name: %s, lr_scale: %f, weight_decay: %f" % (
            name, layer_id, group_name, group["lr_scale"], group["weight_decay"]))

        group["params"].append(param)

    return list(groups.values())
63 |
64 |
def get_layer_id_for_model(name, num_layers):
    """
    Map a parameter name to its layer id.

    Returns 0 for the listed token/embedding parameters, 1 for names starting
    with ``input_encoders``, ``<block index> + 2`` for names starting with
    ``blocks`` (the index is the second dot-separated field), and
    ``num_layers`` for everything else.
    """
    first_level_names = (
        'cls_token', 'patch_embeddings', 'position_embedding',
        'prompt2forecat', 'prompt_tokens', 'mask_tokens', 'cls_tokens',
        'category_tokens',
    )
    if name in first_level_names:
        return 0
    if name.startswith('input_encoders'):
        return 1
    if name.startswith('blocks'):
        block_index = int(name.split('.')[1])
        return block_index + 2
    return num_layers
77 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | # Part of the file is from https://github.com/thuml/Time-Series-Library/blob/main/utils/losses.py
2 | """
3 | Loss functions for PyTorch.
4 | """
5 |
6 | import torch as t
7 | import torch.nn as nn
8 | import numpy as np
9 | import pdb
10 | import torch.nn as nn
11 |
12 |
def divide_no_nan(a, b):
    """
    a/b where the resulted NaN or Inf are replaced by 0.

    Every non-finite entry of the quotient (NaN, +Inf and -Inf) is set to 0,
    so a negative numerator over a zero denominator is zeroed as well. The
    replacement is done in place on the freshly computed quotient; ``a`` and
    ``b`` are left untouched.
    """
    result = a / b
    result[~t.isfinite(result)] = .0
    return result
21 |
22 |
class mape_loss(nn.Module):
    """Mean absolute percentage error restricted by a 0/1 mask."""

    def __init__(self):
        super().__init__()

    def forward(self, insample: t.Tensor, freq: int,
                forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
        """
        MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error

        ``insample`` and ``freq`` are accepted but not used here.

        :param forecast: Forecast values. Shape: batch, time
        :param target: Target values. Shape: batch, time
        :param mask: 0/1 mask. Shape: batch, time
        :return: Loss value
        """
        # mask / target, with NaN/Inf from zero targets replaced by 0.
        inverse_target = divide_no_nan(mask, target)
        scaled_error = (forecast - target) * inverse_target
        return t.mean(t.abs(scaled_error))
39 |
40 |
class smape_loss(nn.Module):
    """Symmetric mean absolute percentage error restricted by a 0/1 mask."""

    def __init__(self):
        super().__init__()

    def forward(self, insample: t.Tensor, freq: int,
                forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
        """
        sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993)

        ``insample`` and ``freq`` are accepted but not used here.

        :param forecast: Forecast values. Shape: batch, time
        :param target: Target values. Shape: batch, time
        :param mask: 0/1 mask. Shape: batch, time
        :return: Loss value
        """
        abs_error = t.abs(forecast - target)
        # The denominator is built from ``.data``, as in the original.
        magnitude = t.abs(forecast.data) + t.abs(target.data)
        ratio = divide_no_nan(abs_error, magnitude)
        return 200 * t.mean(ratio * mask)
57 |
58 |
class mase_loss(nn.Module):
    """Mean absolute scaled error restricted by a 0/1 mask."""

    def __init__(self):
        super().__init__()

    def forward(self, insample: t.Tensor, freq: int,
                forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
        """
        MASE loss as defined in "Scaled Errors" https://robjhyndman.com/papers/mase.pdf

        :param insample: Insample values. Shape: batch, time_i
        :param freq: Frequency value
        :param forecast: Forecast values. Shape: batch, time_o
        :param target: Target values. Shape: batch, time_o
        :param mask: 0/1 mask. Shape: batch, time_o
        :return: Loss value
        """
        # Per-series scale: mean absolute difference at lag ``freq``.
        lagged_diff = insample[:, freq:] - insample[:, :-freq]
        scale = t.mean(t.abs(lagged_diff), dim=1)
        inverse_scale = divide_no_nan(mask, scale[:, None])
        return t.mean(t.abs(target - forecast) * inverse_scale)
78 |
79 |
class UnifiedMaskRecLoss(nn.Module):
    """Masked reconstruction loss over two model outputs."""

    def __init__(self):
        super().__init__()

    def forward_mim_loss(self, target, pred, pad_mask):
        """Mean squared error averaged over positions where ``pad_mask`` is set.

        The squared error is first averaged over the last dimension
        ([N, L], mean loss per patch), then averaged over the positions
        selected by ``pad_mask``.
        """
        per_patch = ((pred - target) ** 2).mean(dim=-1)
        keep = pad_mask.bool()
        return (per_patch * keep).sum() / keep.sum()

    def forward(self, outputs, target, pad_mask):
        """Return a dict with ``cls_loss``, ``mask_loss`` and their sum ``loss``.

        ``outputs`` is a 3-tuple; the third element is ignored. When the
        first element is ``None`` the ``cls_loss`` term is a zero derived
        from ``mask_loss``.
        """
        student_cls, student_fore, _ = outputs

        mask_loss = self.forward_mim_loss(target, student_fore, pad_mask)

        if student_cls is None:
            cls_loss = 0.0 * mask_loss
        else:
            cls_loss = self.forward_mim_loss(target, student_cls, pad_mask)

        return dict(cls_loss=cls_loss,
                    mask_loss=mask_loss,
                    loss=mask_loss + cls_loss)
106 |
--------------------------------------------------------------------------------
/utils/m4_summary.py:
--------------------------------------------------------------------------------
1 | # This file is removed due to LICENSE file constraints.
2 | # You can copy the m4_summary.py from https://github.com/thuml/Time-Series-Library/blob/main/utils/m4_summary.py
--------------------------------------------------------------------------------
/utils/masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class TriangularCausalMask():
    """Boolean mask of shape [B, 1, L, L], True strictly above the diagonal."""

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            self._mask = torch.triu(all_true, diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask
14 |
15 |
class ProbMask():
    """Rows of a strictly-upper-triangular mask selected by ``index``.

    An ``L x scores.shape[-1]`` boolean matrix that is True strictly above
    the diagonal is expanded to ``[B, H, L, scores.shape[-1]]``. For every
    batch/head pair the rows listed in ``index`` are picked and the result
    is reshaped to ``scores.shape``.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        n_cols = scores.shape[-1]
        upper = torch.ones(L, n_cols, dtype=torch.bool).to(device).triu(1)
        expanded = upper[None, None, :].expand(B, H, L, n_cols)

        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        picked_rows = expanded[batch_idx, head_idx, index, :].to(device)

        self._mask = picked_rows.view(scores.shape).to(device)

    @property
    def mask(self):
        return self._mask
29 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def RSE(pred, true):
    """Root of the squared error sum, relative to the spread of ``true`` about its mean."""
    residual_norm = np.sqrt(np.sum((true - pred) ** 2))
    spread_norm = np.sqrt(np.sum((true - true.mean()) ** 2))
    return residual_norm / spread_norm
6 |
7 |
def CORR(pred, true):
    """Pearson correlation along axis 0, averaged over the last axis.

    Both inputs are centred along axis 0. The covariance term is divided by
    the product of the two root sums of squares, i.e.
    ``sqrt(sum(dt**2) * sum(dp**2))``. The earlier form summed the *product*
    of squared deviations, which is not the correlation normaliser and gave
    values outside [-1, 1].
    """
    true_centered = true - true.mean(0)
    pred_centered = pred - pred.mean(0)
    covariance = (true_centered * pred_centered).sum(0)
    normaliser = np.sqrt((true_centered ** 2).sum(0) *
                         (pred_centered ** 2).sum(0))
    return (covariance / normaliser).mean(-1)
13 |
14 |
def MAE(pred, true):
    """Mean absolute error."""
    abs_error = np.abs(pred - true)
    return np.mean(abs_error)
17 |
18 |
def MSE(pred, true):
    """Mean squared error."""
    squared_error = (pred - true) ** 2
    return np.mean(squared_error)
21 |
22 |
def RMSE(pred, true):
    """Root mean squared error (square root of ``MSE``)."""
    mean_squared = MSE(pred, true)
    return np.sqrt(mean_squared)
25 |
26 |
def MAPE(pred, true):
    """Mean absolute percentage error; the error is divided by ``true``."""
    relative_error = (pred - true) / true
    return np.mean(np.abs(relative_error))
29 |
30 |
def MSPE(pred, true):
    """Mean squared percentage error; the error is divided by ``true``."""
    relative_error = (pred - true) / true
    return np.mean(np.square(relative_error))
33 |
34 |
def metric(pred, true):
    """Return the tuple ``(mae, mse, rmse, mape, mspe)`` for the inputs."""
    scorers = (MAE, MSE, RMSE, MAPE, MSPE)
    return tuple(score(pred, true) for score in scorers)
43 |
--------------------------------------------------------------------------------
/utils/timefeatures.py:
--------------------------------------------------------------------------------
1 | # From: gluonts/src/gluonts/time_feature/_base.py
2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License").
5 | # You may not use this file except in compliance with the License.
6 | # A copy of the License is located at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # or in the "license" file accompanying this file. This file is distributed
11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12 | # express or implied. See the License for the specific language governing
13 | # permissions and limitations under the License.
14 |
15 | from typing import List
16 |
17 | import numpy as np
18 | import pandas as pd
19 | from pandas.tseries import offsets
20 | from pandas.tseries.frequencies import to_offset
21 |
22 |
class TimeFeature:
    """Base class for callables that turn a DatetimeIndex into a feature."""

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Subclasses override this; the base implementation returns None.
        pass

    def __repr__(self):
        return f"{self.__class__.__name__}()"
32 |
33 |
class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        fraction = index.second / 59.0
        return fraction - 0.5
39 |
40 |
class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        fraction = index.minute / 59.0
        return fraction - 0.5
46 |
47 |
class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        fraction = index.hour / 23.0
        return fraction - 0.5
53 |
54 |
class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        fraction = index.dayofweek / 6.0
        return fraction - 0.5
60 |
61 |
class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based_day = index.day - 1
        return zero_based_day / 30.0 - 0.5
67 |
68 |
class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based_day = index.dayofyear - 1
        return zero_based_day / 365.0 - 0.5
74 |
75 |
class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based_month = index.month - 1
        return zero_based_month / 11.0 - 0.5
81 |
82 |
class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Week number taken from the index's ISO calendar.
        zero_based_week = index.isocalendar().week - 1
        return zero_based_week / 52.0 - 0.5
88 |
89 |
def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.

    The frequency string is converted to a pandas offset and matched, in
    table order, against the supported offset types; the features of the
    first match are instantiated and returned. A ``RuntimeError`` is raised
    when no type matches.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
    """
    calendar_features = [DayOfWeek, DayOfMonth, DayOfYear]

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: calendar_features,
        offsets.BusinessDay: calendar_features,
        offsets.Hour: [HourOfDay] + calendar_features,
        offsets.Minute: [MinuteOfHour, HourOfDay] + calendar_features,
        offsets.Second: [SecondOfMinute, MinuteOfHour, HourOfDay]
        + calendar_features,
    }

    offset = to_offset(freq_str)

    matched = next(
        (classes for offset_type, classes in features_by_offsets.items()
         if isinstance(offset, offset_type)),
        None,
    )
    if matched is not None:
        return [cls() for cls in matched]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)
145 |
146 |
def time_features(dates, freq='h'):
    """Stack the time features chosen for ``freq`` row-wise, one row per feature."""
    extractors = time_features_from_frequency_str(freq)
    rows = [extractor(dates) for extractor in extractors]
    return np.vstack(rows)
149 |
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from torch import inf
5 |
6 | plt.switch_backend('agg')
7 |
8 |
def adjust_learning_rate(optimizer, epoch, base_lr, args):
    """Set the learning rate of every optimizer param group for ``epoch``.

    Schedules, selected by ``args.lradj``:

    * ``'prompt_tuning'``: before ``args.prompt_tune_epoch`` use
      ``args.learning_rate`` halved once per epoch; at that epoch use
      ``base_lr``; afterwards restart from ``args.learning_rate`` and halve
      once per further epoch.
    * ``'supervised'``: ``base_lr`` up to and including
      ``args.prompt_tune_epoch``, then ``base_lr / 5`` halved once per epoch.
    * ``'finetune_anl'``: ``base_lr`` halved once per epoch.

    A group with an ``"lr_scale"`` entry gets ``lr * lr_scale``.

    Raises:
        ValueError: ``args.lradj`` is not one of the schedules above
            (previously this surfaced as an UnboundLocalError).
    """
    assert args.prompt_tune_epoch >= 0, "args.prompt_tune_epoch >=0!"
    tune_epochs = args.prompt_tune_epoch
    decay_every = 1  # number of epochs between two halvings

    if args.lradj == 'prompt_tuning':
        if epoch < tune_epochs:
            lr = args.learning_rate * (0.5 ** (epoch // decay_every))
        elif epoch == tune_epochs:
            lr = base_lr
        else:
            lr = args.learning_rate * \
                (0.5 ** ((epoch - tune_epochs - 1) // decay_every))
    elif args.lradj == 'supervised':
        if epoch <= tune_epochs:
            lr = base_lr
        else:
            lr = base_lr / 5 * \
                (0.5 ** ((epoch - tune_epochs) // decay_every))
    elif args.lradj == 'finetune_anl':
        lr = base_lr / (2 ** (epoch // decay_every))
    else:
        raise ValueError(
            "Unsupported lradj {!r}; expected 'prompt_tuning', 'supervised' "
            "or 'finetune_anl'".format(args.lradj))

    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    print('Epoch {}: Updating learning rate to {}'.format(epoch+1, lr))
37 |
38 |
class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    def __getattr__(self, name):
        # Missing keys yield None rather than raising.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Deleting a missing key raises KeyError.
        dict.__delitem__(self, name)
44 |
45 |
class StandardScaler():
    """Standardise data with a fixed mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Return ``(data - mean) / std``."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Return ``data * std + mean``, undoing ``transform``."""
        rescaled = data * self.std
        return rescaled + self.mean
56 |
57 |
def visual(true, preds=None, name='./pic/test.pdf'):
    """
    Results visualization

    Plots ``true`` (and ``preds`` when given) and saves the figure to
    ``name``. The figure is always closed afterwards, even when plotting or
    saving fails, so repeated calls do not accumulate open figures.
    """
    fig = plt.figure()
    try:
        plt.plot(true, label='GroundTruth', linewidth=2)
        if preds is not None:
            plt.plot(preds, label='Prediction', linewidth=2)
        plt.legend()
        plt.savefig(name, bbox_inches='tight')
    finally:
        plt.close(fig)
68 |
69 |
def adjustment(gt, pred):
    """Point-adjust predictions over ground-truth anomaly segments.

    When a predicted anomaly falls inside a contiguous run of ``gt == 1``,
    every point of that run is marked as predicted. ``pred`` is modified in
    place and returned together with ``gt``.

    The backward sweep now reaches index 0. Previously it stopped at index
    1, so a segment starting at the first sample kept ``pred[0] == 0``.
    """
    anomaly_state = False
    for i in range(len(gt)):
        if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
            anomaly_state = True
            # Sweep back to the start of the segment (index 0 included).
            for j in range(i, -1, -1):
                if gt[j] == 0:
                    break
                if pred[j] == 0:
                    pred[j] = 1
            # Sweep forward to the end of the segment.
            for j in range(i, len(gt)):
                if gt[j] == 0:
                    break
                if pred[j] == 0:
                    pred[j] = 1
        elif gt[i] == 0:
            anomaly_state = False
        if anomaly_state:
            pred[i] = 1
    return gt, pred
92 |
93 |
def cal_accuracy(y_pred, y_true):
    """Return the mean of the element-wise comparison ``y_pred == y_true``."""
    matches = y_pred == y_true
    return np.mean(matches)
96 |
97 |
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: optional linear warmup, then cosine decay.

    The result has ``epochs * niter_per_ep`` entries. The first
    ``warmup_epochs * niter_per_ep`` rise linearly from
    ``start_warmup_value`` to ``base_value``; the rest follow a half cosine
    from ``base_value`` towards ``final_value``.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    steps = np.arange(total_iters - warmup_iters)
    cosine_factor = 1 + np.cos(np.pi * steps / len(steps))
    decay_part = final_value + 0.5 * (base_value - final_value) * cosine_factor

    full_schedule = np.concatenate((warmup_part, decay_part))
    assert len(full_schedule) == epochs * niter_per_ep
    return full_schedule
112 |
113 |
class NativeScalerWithGradNormCount:
    """Wrap ``torch.cuda.amp.GradScaler``: backward, optional clip, step, update.

    Calling the instance runs the scaled backward pass and, when
    ``update_grad`` is true, unscales the gradients, computes (or clips to)
    their norm, steps the optimizer and updates the scaler. The gradient
    norm is returned; it is ``None`` when ``update_grad`` is false.
    """
    # https://github.com/facebookresearch/mae/blob/main/util/misc.py
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            # Backward only: no optimizer step, no norm.
            return None

        clipping = clip_grad is not None
        if clipping:
            assert parameters is not None
        # unscale the gradients of optimizer's assigned params in-place
        self._scaler.unscale_(optimizer)
        if clipping:
            grad_norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            grad_norm = get_grad_norm_(parameters)

        self._scaler.step(optimizer)
        self._scaler.update()
        return grad_norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
143 |
144 |
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the overall ``norm_type`` norm of the gradients of ``parameters``.

    A single tensor is accepted as well as an iterable. Parameters without a
    gradient are skipped, and a zero tensor is returned when none has one.
    For ``norm_type == inf`` the result is the largest absolute gradient
    entry; otherwise it is the norm of the per-parameter gradient norms.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)

    device = grads[0].device
    if norm_type == inf:
        return max(g.detach().abs().max().to(device) for g in grads)

    per_param_norms = torch.stack(
        [torch.norm(g.detach(), norm_type).to(device) for g in grads])
    return torch.norm(per_param_norms, norm_type)
160 |
161 |
def check_cuda_memory():
    """
    Print the current device's name and its total, allocated and reserved
    memory in GB, or a notice when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    device = torch.cuda.current_device()
    gpu_name = torch.cuda.get_device_name(device)
    total_memory = torch.cuda.get_device_properties(device).total_memory
    allocated_memory = torch.cuda.memory_allocated(device)
    cached_memory = torch.cuda.memory_reserved(device)

    print(f"GPU: {gpu_name}")
    print(f"Total Memory: {total_memory / 1e9:.5f} GB")
    print(f"Allocated Memory: {allocated_memory / 1e9:.5f} GB")
    print(f"Cached Memory: {cached_memory / 1e9:.5f} GB")
180 |
--------------------------------------------------------------------------------