├── .gitignore ├── LICENSE ├── README.md ├── Tutorial.md ├── data_provider ├── __init__.py ├── anomaly_detection.yaml ├── data_factory.py ├── data_loader.py ├── fewshot_new_task.yaml ├── imputation.yaml ├── m4.py ├── multi_task.yaml ├── multi_task_pretrain.yaml ├── multitask_zero_shot_new_length.yaml ├── uea.py └── zeroshot_task.yaml ├── download_data_all.sh ├── exp ├── __init__.py ├── exp_pretrain.py └── exp_sup.py ├── models ├── UniTS.py ├── UniTS_zeroshot.py └── __init__.py ├── requirements.txt ├── run.py ├── run_pretrain.py ├── scripts ├── few_shot_anomaly_detection │ ├── UniTS_finetune_few_shot_anomaly_detection.sh │ └── UniTS_prompt_tuning_few_shot_anomaly_detection.sh ├── few_shot_imputation │ ├── UniTS_finetune_few_shot_imputation_mask025.sh │ ├── UniTS_finetune_few_shot_imputation_mask050.sh │ ├── UniTS_prompt_tuning_few_shot_imputation_mask025.sh │ └── UniTS_prompt_tuning_few_shot_imputation_mask050.sh ├── few_shot_newdata │ ├── UniTS_finetune_few_shot_newdata_pct05.sh │ ├── UniTS_finetune_few_shot_newdata_pct15.sh │ ├── UniTS_finetune_few_shot_newdata_pct20.sh │ ├── UniTS_prompt_tuning_few_shot_newdata_pct05.sh │ ├── UniTS_prompt_tuning_few_shot_newdata_pct15.sh │ └── UniTS_prompt_tuning_few_shot_newdata_pct20.sh ├── pretrain_prompt_learning │ ├── UniTS_pretrain_x128.sh │ ├── UniTS_pretrain_x32.sh │ └── UniTS_pretrain_x64.sh ├── supervised_learning │ └── UniTS_supervised.sh └── zero_shot │ ├── UniTS_forecast_new_length_unify.sh │ └── UniTS_zeroshot_newdata.sh └── utils ├── __init__.py ├── dataloader.py ├── ddp.py ├── layer_decay.py ├── losses.py ├── m4_summary.py ├── masking.py ├── metrics.py ├── timefeatures.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 
| .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | /scripts/long_term_forecast/Traffic_script/PatchTST1.sh 131 | /backups/ 132 | /result.xlsx 133 | /~$result.xlsx 134 | /Time-Series-Library.zip 135 | /temp.sh 136 | 137 | .idea 138 | /tv_result.xlsx 139 | /test.py 140 | /m4_results/ 141 | /test_results/ 142 | /PatchTST_results.xlsx 143 | /seq_len_long_term_forecast/ 144 | /progress.xlsx 145 | /scripts/short_term_forecast/PatchTST_M4.sh 146 | /run_tv.py 147 | 148 | /scripts/long_term_forecast/ETT_tv_script/ 149 | /dataset/ 150 | data_factory_all.py 151 | data_loader_all.py 152 | /scripts/short_term_forecast/tv_script/ 153 | /exp/exp_short_term_forecasting_tv.py 154 | /exp/exp_long_term_forecasting_tv.py 155 | /timesnetv2.xlsx 156 | /scripts/anomaly_detection/tmp/ 157 | /scripts/imputation/tmp/ 158 | /utils/self_tools.py 159 | 160 | checkpoints 161 | logs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons 
to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unified Time Series Model 2 | 3 | [**Project Page**](https://zitniklab.hms.harvard.edu/projects/UniTS/) | [**Paper link**](https://arxiv.org/pdf/2403.00131.pdf) **(Neurips 2024)** 4 | 5 | UniTS is a unified time series model that can process various tasks across multiple domains with shared parameters and does not have any task-specific modules. 6 | 7 | Authors: [Shanghua Gao](https://shgao.site/) [Teddy Koker](https://teddykoker.com) [Owen Queen](https://owencqueen.github.io/) [Thomas Hartvigsen](https://www.tomhartvigsen.com/) [Theodoros Tsiligkaridis](https://sites.google.com/view/theo-t) [Marinka Zitnik](https://zitniklab.hms.harvard.edu/) 8 | 9 | ## Overview 10 | Foundation models, especially LLMs, are profoundly transforming deep learning. Instead of training many task-specific models, we can adapt a single pretrained model to many tasks via few-shot prompting or fine-tuning. 
However, current foundation models apply to sequence data but not to time series, which present unique challenges due to the inherent diverse and multi-domain time series datasets, diverging task specifications across forecasting, classification and other types of tasks, and the apparent need for task-specialized models. 11 | 12 | We developed UniTS, a unified time series model that supports a universal task specification, accommodating classification, forecasting, imputation, and anomaly detection tasks. This is achieved through a novel unified network backbone, which incorporates sequence and variable attention along with a dynamic linear operator and is trained as a unified model. 13 | 14 | Across 38 multi-domain datasets, UniTS demonstrates superior performance compared to task-specific models and repurposed natural language-based LLMs. UniTS exhibits remarkable zero-shot, few-shot, and prompt learning capabilities when evaluated on new data domains and tasks. 15 | 16 |

17 | UniTS-1 18 |

19 | 20 | ## Setups 21 | 22 | ### 1. Requirements 23 | Install PyTorch 2.0+ and the required packages. 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ### 2. Prepare data 29 | ``` 30 | bash download_data_all.sh 31 | ``` 32 | Dataset configs for different multi-task settings are shown in `.yaml` files of the `data_provider` folder. 33 | 34 | By default, all experiments follow the multi-task setting where one UniTS model is jointly trained on multiple datasets. 35 | 36 | ### 3. Train and evaluate model 37 | 38 | #### 1. Multi-task learning on forecasting and classification tasks: 39 | 40 | - Pretraining + Prompt learning 41 | ``` 42 | bash ./scripts/pretrain_prompt_learning/UniTS_pretrain_x128.sh 43 | ``` 44 | 45 | - Supervised learning 46 | ``` 47 | bash ./scripts/supervised_learning/UniTS_supervised.sh 48 | ``` 49 | 50 | #### 2. Few-shot transfer learning on new forecasting and classification tasks: 51 | 52 | **Note: Please follow the instructions in the following training scripts to get the pretrained ckpt first.** 53 | 54 | - Finetuning 55 | ``` 56 | # please set the pretrained model path in the script. 57 | bash ./scripts/few_shot_newdata/UniTS_finetune_few_shot_newdata_pct20.sh 58 | ``` 59 | 60 | - Prompt tuning 61 | ``` 62 | # please set the pretrained model path in the script. 63 | bash ./scripts/few_shot_newdata/UniTS_prompt_tuning_few_shot_newdata_pct20.sh 64 | ``` 65 | 66 | #### 3. Few-shot transfer learning on anomaly detection tasks: 67 | - Finetuning 68 | ``` 69 | # please set the pretrained model path in the script. 70 | bash ./scripts/few_shot_anomaly_detection/UniTS_finetune_few_shot_anomaly_detection.sh 71 | ``` 72 | - Prompt tuning 73 | ``` 74 | # please set the pretrained model path in the script. 75 | bash ./scripts/few_shot_anomaly_detection/UniTS_prompt_tuning_few_shot_anomaly_detection.sh 76 | ``` 77 | 78 | #### 4. 
Few-shot transfer learning on imputation tasks: 79 | - Finetuning 80 | ``` 81 | # please set the pretrained model path in the script. 82 | bash ./scripts/few_shot_imputation/UniTS_finetune_few_shot_imputation_mask050.sh 83 | ``` 84 | 85 | - Prompt tuning 86 | ``` 87 | # please set the pretrained model path in the script. 88 | bash ./scripts/few_shot_imputation/UniTS_prompt_tuning_few_shot_imputation_mask050.sh 89 | ``` 90 | 91 | #### 5. Zero-shot learning on new forecasting length: 92 | ``` 93 | # please set the pretrained model path in the script. 94 | bash ./scripts/zero_shot/UniTS_forecast_new_length_unify.sh 95 | ``` 96 | 97 | #### 6. Zero-shot learning on new forecasting datasets: 98 | ``` 99 | # A special version of UniTS with shared prompt/mask tokens needs to be trained for this setting. 100 | bash ./scripts/zero_shot/UniTS_zeroshot_newdata.sh 101 | ``` 102 | 103 | ## Use UniTS on your own data. 104 | UniTS is a highly flexible unified time series model, supporting tasks such as forecasting, classification, imputation, and anomaly detection with a single shared model and shared weights. We provide a [Tutorial](Tutorial.md) to assist you in using your own data with UniTS. 105 | 106 | ## Pretrained weights 107 | We provide the pretrained weights for models mentioned above in [checkpoints](https://github.com/mims-harvard/UniTS/releases/tag/ckpt). 108 | 109 | ## Citation 110 | 111 | ``` 112 | @article{gao2024building, 113 | title={UniTS: Building a Unified Time Series Model}, 114 | author={Gao, Shanghua and Koker, Teddy and Queen, Owen and Hartvigsen, Thomas and Tsiligkaridis, Theodoros and Zitnik, Marinka}, 115 | journal={arXiv}, 116 | url={https://arxiv.org/pdf/2403.00131.pdf}, 117 | year={2024} 118 | } 119 | ``` 120 | 121 | ## Acknowledgement 122 | This codebase is built based on the [Time-Series-Library](https://github.com/thuml/Time-Series-Library). Thanks! 123 | 124 | ## Disclaimer 125 | 126 | DISTRIBUTION STATEMENT: Approved for public release. 
Distribution is unlimited. 127 | 128 | This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering. 129 | 130 | © 2024 Massachusetts Institute of Technology. 131 | 132 | Subject to FAR52.227-11 Patent Rights - Ownership by the contractor (May 2014) 133 | 134 | The software/firmware is provided to you on an As-Is basis 135 | 136 | Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other than as specifically authorized by the U.S. Government may violate any copyrights that exist in this work. 137 | -------------------------------------------------------------------------------- /Tutorial.md: -------------------------------------------------------------------------------- 1 | # Quick start for using UniTS on your own data. 2 | 3 | ## Classification with your own data. 4 | 5 | We use a classification task as an example. The primary difference for other tasks lies in the data formats. You can follow the provided dataset as a guide to adapt your own data. 6 | 7 | ### 1. Prepare data 8 | 9 | We support common data formats of time series datasets. 10 | 11 | You can follow the [dataset format guide](https://www.aeon-toolkit.org/en/latest/examples/datasets/data_loading.html) to convert your dataset into the `.ts` format. 12 | 13 | The dataset should contain `newdata_TRAIN.ts` and `newdata_TEST.ts` files. 14 | 15 | ### 2. 
Define the dataset config file 16 | 17 | To support multiple datasets, our code base uses the `data_set.yaml` to keep the dataset information. 18 | Examples can be found in `data_provider` folder. 19 | 20 | Here is an example for classification dataset. You can add multiple dataset config in one config file if you want to make UniTS support multiple datasets. 21 | ```yaml 22 | task_dataset: 23 | CLS_ECG5000: # the dataset and task name 24 | task_name: classification # the type of task 25 | dataset: ECG5000 # the name of the dataset 26 | data: UEA # the data type of the dataset, use UEA if you use the '.ts' file 27 | embed: timeF # the embedding method used 28 | root_path: ../dataset/UCR/ECG5000 # the root path of the dataset 29 | seq_len: 140 # the length of the input sequence 30 | label_len: 0 # the length of the label sequence, 0 for classification 31 | pred_len: 0 # the length of the predicted sequence, 0 for classification 32 | enc_in: 1 # the number of variable numbers 33 | num_class: 5 # the number of classes 34 | c_out: None # the output variable numbers, 0 for classification 35 | ``` 36 | 37 | ### 3. Finetune your UniTS model 38 | 39 | #### Load Pretrained weights (Optional) 40 | You can load the pretrained SSL/Supervised UniTS model. 41 | Run [SSL Pretraining]() or [Supervised training]() scripts to get the pretrained checkpoints. 42 | Normally, SSL pretrained model has better transfer learning abilities. 
43 | 44 | #### Setup finetuning script 45 | 46 | **Note: Remove captions before using the following scripts!** 47 | 48 | - Finetuning/Supervised training 49 | ```bash 50 | model_name=UniTS # Model name, UniTS 51 | exp_name=UniTS_supervised_x64 # Exp name 52 | wandb_mode=online # Use wandb to log the training, change to disabled if you don't want to use it 53 | project_name=supervised_learning # project name in wandb 54 | 55 | random_port=$((RANDOM % 9000 + 1000)) 56 | 57 | # Supervised learning 58 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \ 59 | --is_training 1 \ # 1 for training, 0 for testing 60 | --model_id $exp_name \ 61 | --model $model_name \ 62 | --lradj supervised \ # You can define your own lr decay scheme in the adjust_learning_rate function of utils/tools.py 63 | --prompt_num 10 \ # The number of prompt tokens. 64 | --patch_len 16 \ # Patch size for each token in UniTS 65 | --stride 16 \ # Stride = patch size 66 | --e_layers 3 \ 67 | --d_model 64 \ 68 | --des 'Exp' \ 69 | --learning_rate 1e-4 \ # Tune the following hp for your datasets. Due to the highly diverse nature of time series data, you might need to tune the hp for your new data. 70 | --weight_decay 5e-6 \ 71 | --train_epochs 5 \ 72 | --batch_size 32 \ # Real batch size = batch_size * acc_it 73 | --acc_it 32 \ 74 | --debug $wandb_mode \ 75 | --project_name $project_name \ 76 | --clip_grad 100 \ # Grad clip to avoid NaN. 77 | --pretrained_weight ckpt_path.pth \ # Path of pretrained ckpt if you want to finetune the model, otherwise just remove it 78 | --task_data_config_path data_provider/multi_task.yaml # Important: Change to your_own_data_config.yaml 79 | 80 | ``` 81 | 82 | - Prompt learning 83 | 84 | For prompt learning, only the prompt tokens are finetuned and the model weights are kept fixed. 
85 | **You must load pretrained model weights.** 86 | ```bash 87 | # Prompt tuning 88 | torchrun --nnodes 1 --master_port $random_port run.py \ 89 | --is_training 1 \ 90 | --model_id $exp_name \ 91 | --model $model_name \ 92 | --lradj prompt_tuning \ 93 | --prompt_num 10 \ 94 | --patch_len 16 \ 95 | --stride 16 \ 96 | --e_layers 3 \ 97 | --d_model $d_model \ 98 | --des 'Exp' \ 99 | --itr 1 \ 100 | --learning_rate 3e-3 \ 101 | --weight_decay 0 \ 102 | --prompt_tune_epoch 2 \ # Number of epochs for prompt tuning 103 | --train_epochs 0 \ 104 | --acc_it 32 \ 105 | --debug $wandb_mode \ 106 | --project_name $ptune_name \ 107 | --clip_grad 100 \ 108 | --pretrained_weight auto \ # Path of pretrained ckpt, you must add it for prompt learning 109 | --task_data_config_path data_provider/multi_task.yaml # Important: Change to your_own_data_config.yaml 110 | ``` 111 | 112 | ### 113 | Feel free to open an issue if you have any problems in using our code. 114 | 115 | This doc will be updated. -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/UniTS/0e0281482864017cac8832b2651906ff5375a34e/data_provider/__init__.py -------------------------------------------------------------------------------- /data_provider/anomaly_detection.yaml: -------------------------------------------------------------------------------- 1 | task_dataset: 2 | MSL: 3 | task_name: anomaly_detection 4 | dataset_name: MSL 5 | dataset: MSL 6 | data: MSL 7 | root_path: ../dataset/MSL 8 | seq_len: 96 9 | label_len: 0 10 | pred_len: 0 11 | features: M 12 | embed: timeF 13 | enc_in: 55 14 | dec_in: 55 15 | c_out: 55 16 | 17 | PSM: 18 | task_name: anomaly_detection 19 | dataset_name: PSM 20 | dataset: PSM 21 | data: PSM 22 | root_path: ../dataset/PSM 23 | seq_len: 96 24 | label_len: 0 25 | pred_len: 0 26 | features: M 27 | embed: 
def random_subset(dataset, pct, seed):
    """Return a reproducible random subset of ``dataset``.

    The permutation is drawn from a private ``torch.Generator`` seeded with
    ``seed``, so the same ``(len(dataset), pct, seed)`` always selects the
    same samples.

    Args:
        dataset: Sized, indexable dataset to draw from.
        pct: Fraction of samples to keep; ``int(len(dataset) * pct)`` samples
            are selected. Values above 1 keep every sample because the slice
            is clamped to the permutation length.
        seed: Integer seed for the permutation.

    Returns:
        A ``torch.utils.data.Subset`` over the selected indices.

    Raises:
        ValueError: If ``pct`` is negative.
    """
    if pct < 0:
        # A negative count would be read as "drop the last k" by the slice
        # below and silently keep almost the whole dataset.
        raise ValueError(f"pct must be non-negative, got {pct!r}")
    generator = torch.Generator()
    generator.manual_seed(seed)
    idx = torch.randperm(len(dataset), generator=generator)
    keep = int(len(dataset) * pct)
    return Subset(dataset, idx[:keep].long().numpy())


def _maybe_subsample(args, data_set, flag):
    """Apply ``args.subsample_pct`` to the train split; pass other splits through."""
    if args.subsample_pct is not None and flag == "train":
        return random_subset(data_set, args.subsample_pct, args.fix_seed)
    return data_set


def _build_loader(args, data_set, batch_size, shuffle_flag, drop_last, ddp,
                  collate=None):
    """Create a DataLoader, handing ordering to a DistributedSampler when ``ddp``."""
    return DataLoader(
        data_set,
        batch_size=batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler.
        shuffle=False if ddp else shuffle_flag,
        num_workers=args.num_workers,
        sampler=DistributedSampler(data_set) if ddp else None,
        drop_last=drop_last,
        collate_fn=collate,
    )


def data_provider(args, config, flag, ddp=False):
    """Build the dataset and DataLoader for one task entry of a config file.

    Args:
        args: Run-level options. Reads ``batch_size``, ``freq``,
            ``num_workers``, ``subsample_pct``, ``fix_seed`` and (forecasting
            path only) ``target``.
        config: One task entry. Always reads ``data``, ``embed``,
            ``root_path``, ``seq_len``; other keys depend on the branch taken
            (``task_name``, ``dataset_name``, ``label_len``, ``pred_len``,
            ``features``, ``data_path``, ``seasonal_patterns``).
        flag: Split name. ``'test'`` disables shuffling and uses batch size 1
            (``args.batch_size`` for anomaly detection); any other value
            shuffles and uses ``args.batch_size``. Subsampling is applied only
            when ``flag == 'train'``.
        ddp: When true, a ``DistributedSampler`` is attached to the loader
            (not on the gluonts path, see note below).

    Returns:
        Tuple ``(data_set, data_loader)``.

    Raises:
        KeyError: If ``config['data']`` is not registered in ``data_dict``.
    """
    data_key = config['data']
    if data_key not in data_dict:
        # 'm4' lands here unless m4.py is restored and re-registered.
        raise KeyError(
            f"Unsupported data type {data_key!r}; "
            f"registered types: {sorted(data_dict)}"
        )
    Data = data_dict[data_key]
    timeenc = 0 if config['embed'] != 'timeF' else 1
    freq = args.freq

    if flag == 'test':
        shuffle_flag = False
        drop_last = False
        if 'anomaly_detection' in config['task_name']:  # working on one gpu
            batch_size = args.batch_size
        else:
            batch_size = 1  # bsz=1 for evaluation
    else:
        shuffle_flag = True
        drop_last = True
        batch_size = args.batch_size  # bsz for train and valid

    if 'gluonts' in data_key:
        # process gluonts dataset:
        data_set = Data(
            dataset_name=config['dataset_name'],
            size=(config['seq_len'], config['label_len'], config['pred_len']),
            path=config['root_path'],
            # Don't set dataset_writer
            features=config["features"],
            flag=flag,
        )
        data_set = _maybe_subsample(args, data_set, flag)
        # NOTE(review): this path never attaches a DistributedSampler, even
        # when ddp=True. Kept as in the original; confirm this is intended.
        data_loader = _build_loader(
            args, data_set, batch_size, shuffle_flag, drop_last, ddp=False)
        return data_set, data_loader

    task_name = config['task_name']

    if 'anomaly_detection' in task_name:
        data_set = Data(
            root_path=config['root_path'],
            win_size=config['seq_len'],
            flag=flag,
        )
        data_set = _maybe_subsample(args, data_set, flag)
        print("ddp mode is set to false for anomaly_detection", ddp, len(data_set))
        data_loader = _build_loader(
            args, data_set, batch_size, shuffle_flag, False, ddp)
        return data_set, data_loader

    if 'classification' in task_name:
        data_set = Data(
            root_path=config['root_path'],
            flag=flag,
        )
        data_set = _maybe_subsample(args, data_set, flag)
        print(flag, len(data_set))
        data_loader = _build_loader(
            args, data_set, batch_size, shuffle_flag, False, ddp,
            collate=lambda x: collate_fn(x, max_len=config['seq_len']),
        )
        return data_set, data_loader

    # Remaining tasks (forecasting / imputation) share one constructor shape.
    if data_key == 'm4':
        drop_last = False
    data_set = Data(
        root_path=config['root_path'],
        data_path=config['data_path'],
        flag=flag,
        size=[config['seq_len'], config['label_len'], config['pred_len']],
        features=config['features'],
        target=args.target,
        timeenc=timeenc,
        freq=freq,
        seasonal_patterns=config['seasonal_patterns'] if data_key == 'm4' else None
    )
    data_set = _maybe_subsample(args, data_set, flag)
    print(flag, len(data_set))
    data_loader = _build_loader(
        args, data_set, batch_size, shuffle_flag, drop_last, ddp)
    return data_set, data_loader
-------------------------------------------------------------------------------- 1 | task_dataset: 2 | 3 | CLS_ECG200: 4 | task_name: classification 5 | dataset: ECG200 6 | data: UEA 7 | embed: timeF 8 | root_path: ../dataset/ECG200 9 | seq_len: 96 10 | label_len: 0 11 | pred_len: 0 12 | enc_in: 1 13 | num_class: 2 14 | c_out: None 15 | 16 | CLS_Handwriting: 17 | task_name: classification 18 | dataset: Handwriting 19 | data: UEA 20 | embed: timeF 21 | root_path: ../dataset/Handwriting 22 | seq_len: 152 23 | label_len: 0 24 | pred_len: 0 25 | enc_in: 3 26 | num_class: 26 27 | c_out: None 28 | 29 | CLS_SelfRegulationSCP1: 30 | task_name: classification 31 | dataset: SelfRegulationSCP1 32 | data: UEA 33 | embed: timeF 34 | root_path: ../dataset/SelfRegulationSCP1 35 | seq_len: 896 36 | label_len: 0 37 | pred_len: 0 38 | enc_in: 6 39 | num_class: 2 40 | c_out: None 41 | 42 | 43 | CLS_RacketSports: 44 | task_name: classification 45 | dataset: RacketSports 46 | data: UEA 47 | embed: timeF 48 | root_path: ../dataset/UAE/RacketSports 49 | seq_len: 30 50 | label_len: 0 51 | pred_len: 0 52 | enc_in: 6 53 | num_class: 4 54 | c_out: None 55 | 56 | CLS_Epilepsy: 57 | task_name: classification 58 | dataset: Epilepsy 59 | data: UEA 60 | embed: timeF 61 | root_path: ../dataset/UAE/Epilepsy 62 | seq_len: 207 63 | label_len: 0 64 | pred_len: 0 65 | enc_in: 3 66 | num_class: 4 67 | c_out: None 68 | 69 | CLS_StarLightCurves: 70 | task_name: classification 71 | dataset: StarLightCurves 72 | data: UEA 73 | embed: timeF 74 | root_path: ../dataset/UCR/StarLightCurves 75 | seq_len: 1024 76 | label_len: 0 77 | pred_len: 0 78 | enc_in: 1 79 | num_class: 3 80 | c_out: None 81 | 82 | LTF_ETTh2_p96: 83 | task_name: long_term_forecast 84 | dataset: ETTh2 85 | data: ETTh2 86 | embed: timeF 87 | root_path: ../dataset/ETT-small/ 88 | data_path: ETTh2.csv 89 | features: M 90 | seq_len: 96 91 | label_len: 48 92 | pred_len: 96 93 | enc_in: 7 94 | dec_in: 7 95 | c_out: 7 96 | 97 | LTF_ETTh2_p192: 98 | 
task_name: long_term_forecast 99 | dataset: ETTh2 100 | data: ETTh2 101 | embed: timeF 102 | root_path: ../dataset/ETT-small/ 103 | data_path: ETTh2.csv 104 | features: M 105 | seq_len: 96 106 | label_len: 48 107 | pred_len: 192 108 | enc_in: 7 109 | dec_in: 7 110 | c_out: 7 111 | 112 | LTF_ETTh2_p336: 113 | task_name: long_term_forecast 114 | dataset: ETTh2 115 | data: ETTh2 116 | embed: timeF 117 | root_path: ../dataset/ETT-small/ 118 | data_path: ETTh2.csv 119 | features: M 120 | seq_len: 96 121 | label_len: 48 122 | pred_len: 336 123 | enc_in: 7 124 | dec_in: 7 125 | c_out: 7 126 | 127 | LTF_ETTh2_p720: 128 | task_name: long_term_forecast 129 | dataset: ETTh2 130 | data: ETTh2 131 | embed: timeF 132 | root_path: ../dataset/ETT-small/ 133 | data_path: ETTh2.csv 134 | features: M 135 | seq_len: 96 136 | label_len: 48 137 | pred_len: 720 138 | enc_in: 7 139 | dec_in: 7 140 | c_out: 7 141 | 142 | LTF_SaugeenRiverFlow: 143 | task_name: long_term_forecast 144 | dataset_name: saugeenday 145 | dataset: SaugeenRiverFlow 146 | data: gluonts 147 | root_path: ../dataset/gluonts 148 | seq_len: 48 149 | label_len: 0 150 | pred_len: 24 151 | features: M 152 | embed: timeF 153 | enc_in: 1 154 | dec_in: 1 155 | c_out: 1 156 | 157 | LTF_ETTm1_p96: 158 | task_name: long_term_forecast 159 | dataset: ETTm1 160 | data: ETTm1 161 | embed: timeF 162 | root_path: ../dataset/ETT-small/ 163 | data_path: ETTm1.csv 164 | features: M 165 | seq_len: 96 166 | label_len: 48 167 | pred_len: 96 168 | enc_in: 7 169 | dec_in: 7 170 | c_out: 7 171 | 172 | LTF_ETTm1_p192: 173 | task_name: long_term_forecast 174 | dataset: ETTm1 175 | data: ETTm1 176 | embed: timeF 177 | root_path: ../dataset/ETT-small/ 178 | data_path: ETTm1.csv 179 | features: M 180 | seq_len: 96 181 | label_len: 48 182 | pred_len: 192 183 | enc_in: 7 184 | dec_in: 7 185 | c_out: 7 186 | 187 | LTF_ETTm1_p336: 188 | task_name: long_term_forecast 189 | dataset: ETTm1 190 | data: ETTm1 191 | embed: timeF 192 | root_path: 
../dataset/ETT-small/ 193 | data_path: ETTm1.csv 194 | features: M 195 | seq_len: 96 196 | label_len: 48 197 | pred_len: 336 198 | enc_in: 7 199 | dec_in: 7 200 | c_out: 7 201 | 202 | LTF_ETTm1_p720: 203 | task_name: long_term_forecast 204 | dataset: ETTm1 205 | data: ETTm1 206 | embed: timeF 207 | root_path: ../dataset/ETT-small/ 208 | data_path: ETTm1.csv 209 | features: M 210 | seq_len: 96 211 | label_len: 48 212 | pred_len: 720 213 | enc_in: 7 214 | dec_in: 7 215 | c_out: 7 216 | 217 | -------------------------------------------------------------------------------- /data_provider/imputation.yaml: -------------------------------------------------------------------------------- 1 | task_dataset: 2 | LTF_ECL_p96: 3 | task_name: imputation 4 | dataset: ECL 5 | data: custom 6 | embed: timeF 7 | root_path: ../dataset/electricity/ 8 | data_path: electricity.csv 9 | features: M 10 | seq_len: 96 11 | label_len: 0 12 | pred_len: 0 13 | enc_in: 321 14 | dec_in: 321 15 | c_out: 321 16 | 17 | LTF_ETTh1_p96: 18 | task_name: imputation 19 | dataset: ETTh1 20 | data: ETTh1 21 | embed: timeF 22 | root_path: ../dataset/ETT-small/ 23 | data_path: ETTh1.csv 24 | features: M 25 | seq_len: 96 26 | label_len: 0 27 | pred_len: 0 28 | enc_in: 7 29 | dec_in: 7 30 | c_out: 7 31 | 32 | LTF_Weather_p96: 33 | task_name: imputation 34 | dataset: Weather 35 | data: custom 36 | embed: timeF 37 | root_path: ../dataset/weather/ 38 | data_path: weather.csv 39 | features: M 40 | seq_len: 96 41 | label_len: 0 42 | pred_len: 0 43 | enc_in: 21 44 | dec_in: 21 45 | c_out: 21 46 | 47 | LTF_ETTh2_p96: 48 | task_name: imputation 49 | dataset: ETTh2 50 | data: ETTh2 51 | embed: timeF 52 | root_path: ../dataset/ETT-small/ 53 | data_path: ETTh2.csv 54 | features: M 55 | seq_len: 96 56 | label_len: 0 57 | pred_len: 0 58 | enc_in: 7 59 | dec_in: 7 60 | c_out: 7 61 | 62 | LTF_ETTm1_p96: 63 | task_name: imputation 64 | dataset: ETTm1 65 | data: ETTm1 66 | embed: timeF 67 | root_path: ../dataset/ETT-small/ 68 | 
data_path: ETTm1.csv 69 | features: M 70 | seq_len: 96 71 | label_len: 0 72 | pred_len: 0 73 | enc_in: 7 74 | dec_in: 7 75 | c_out: 7 76 | 77 | LTF_ETTm2_p96: 78 | task_name: imputation 79 | dataset: ETTm2 80 | data: ETTm2 81 | embed: timeF 82 | root_path: ../dataset/ETT-small/ 83 | data_path: ETTm2.csv 84 | features: M 85 | seq_len: 96 86 | label_len: 0 87 | pred_len: 0 88 | enc_in: 7 89 | dec_in: 7 90 | c_out: 7 91 | -------------------------------------------------------------------------------- /data_provider/m4.py: -------------------------------------------------------------------------------- 1 | # This file is removed due to LICENSE file constraints. 2 | # You can copy the m4.py from https://github.com/thuml/Time-Series-Library/blob/main/data_provider/m4.py -------------------------------------------------------------------------------- /data_provider/multi_task.yaml: -------------------------------------------------------------------------------- 1 | task_dataset: 2 | NN5_p112: 3 | task_name: long_term_forecast 4 | dataset_name: nn5_daily_without_missing 5 | dataset: NN5 6 | data: gluonts 7 | root_path: ../dataset/gluonts 8 | seq_len: 112 9 | label_len: 0 10 | pred_len: 112 11 | features: M 12 | embed: timeF 13 | enc_in: 111 14 | dec_in: 111 15 | c_out: 111 16 | 17 | LTF_ECL_p96: 18 | task_name: long_term_forecast 19 | dataset: ECL 20 | data: custom 21 | embed: timeF 22 | root_path: ../dataset/electricity/ 23 | data_path: electricity.csv 24 | features: M 25 | seq_len: 96 26 | label_len: 48 27 | pred_len: 96 28 | enc_in: 321 29 | dec_in: 321 30 | c_out: 321 31 | 32 | LTF_ECL_p192: 33 | task_name: long_term_forecast 34 | dataset: ECL 35 | data: custom 36 | embed: timeF 37 | root_path: ../dataset/electricity/ 38 | data_path: electricity.csv 39 | features: M 40 | seq_len: 96 41 | label_len: 48 42 | pred_len: 192 43 | enc_in: 321 44 | dec_in: 321 45 | c_out: 321 46 | 47 | LTF_ECL_p336: 48 | task_name: long_term_forecast 49 | dataset: ECL 50 | data: custom 51 | 
embed: timeF 52 | root_path: ../dataset/electricity/ 53 | data_path: electricity.csv 54 | features: M 55 | seq_len: 96 56 | label_len: 48 57 | pred_len: 336 58 | enc_in: 321 59 | dec_in: 321 60 | c_out: 321 61 | 62 | LTF_ECL_p720: 63 | task_name: long_term_forecast 64 | dataset: ECL 65 | data: custom 66 | embed: timeF 67 | root_path: ../dataset/electricity/ 68 | data_path: electricity.csv 69 | features: M 70 | seq_len: 96 71 | label_len: 48 72 | pred_len: 720 73 | enc_in: 321 74 | dec_in: 321 75 | c_out: 321 76 | 77 | LTF_ETTh1_p96: 78 | task_name: long_term_forecast 79 | dataset: ETTh1 80 | data: ETTh1 81 | embed: timeF 82 | root_path: ../dataset/ETT-small/ 83 | data_path: ETTh1.csv 84 | features: M 85 | seq_len: 96 86 | label_len: 48 87 | pred_len: 96 88 | enc_in: 7 89 | dec_in: 7 90 | c_out: 7 91 | 92 | LTF_ETTh1_p192: 93 | task_name: long_term_forecast 94 | dataset: ETTh1 95 | data: ETTh1 96 | embed: timeF 97 | root_path: ../dataset/ETT-small/ 98 | data_path: ETTh1.csv 99 | features: M 100 | seq_len: 96 101 | label_len: 48 102 | pred_len: 192 103 | enc_in: 7 104 | dec_in: 7 105 | c_out: 7 106 | 107 | LTF_ETTh1_p336: 108 | task_name: long_term_forecast 109 | dataset: ETTh1 110 | data: ETTh1 111 | embed: timeF 112 | root_path: ../dataset/ETT-small/ 113 | data_path: ETTh1.csv 114 | features: M 115 | seq_len: 96 116 | label_len: 48 117 | pred_len: 336 118 | enc_in: 7 119 | dec_in: 7 120 | c_out: 7 121 | 122 | LTF_ETTh1_p720: 123 | task_name: long_term_forecast 124 | dataset: ETTh1 125 | data: ETTh1 126 | embed: timeF 127 | root_path: ../dataset/ETT-small/ 128 | data_path: ETTh1.csv 129 | features: M 130 | seq_len: 96 131 | label_len: 48 132 | pred_len: 720 133 | enc_in: 7 134 | dec_in: 7 135 | c_out: 7 136 | 137 | LTF_Exchange_p192: 138 | task_name: long_term_forecast 139 | dataset: Exchange 140 | data: custom 141 | embed: timeF 142 | root_path: ../dataset/exchange_rate/ 143 | data_path: exchange_rate.csv 144 | features: M 145 | seq_len: 96 146 | label_len: 48 147 
| pred_len: 192 148 | enc_in: 8 149 | dec_in: 8 150 | c_out: 8 151 | 152 | LTF_Exchange_p336: 153 | task_name: long_term_forecast 154 | dataset: Exchange 155 | data: custom 156 | embed: timeF 157 | root_path: ../dataset/exchange_rate/ 158 | data_path: exchange_rate.csv 159 | features: M 160 | seq_len: 96 161 | label_len: 48 162 | pred_len: 336 163 | enc_in: 8 164 | dec_in: 8 165 | c_out: 8 166 | 167 | LTF_ILI_p60: 168 | task_name: long_term_forecast 169 | dataset: ILI 170 | data: custom 171 | embed: timeF 172 | root_path: ../dataset/illness/ 173 | data_path: national_illness.csv 174 | features: M 175 | seq_len: 36 176 | label_len: 18 177 | pred_len: 60 178 | enc_in: 7 179 | dec_in: 7 180 | c_out: 7 181 | 182 | LTF_Traffic_p96: 183 | task_name: long_term_forecast 184 | dataset: Traffic 185 | data: custom 186 | embed: timeF 187 | root_path: ../dataset/traffic/ 188 | data_path: traffic.csv 189 | features: M 190 | seq_len: 96 191 | label_len: 48 192 | pred_len: 96 193 | enc_in: 862 194 | dec_in: 862 195 | c_out: 862 196 | 197 | LTF_Traffic_p192: 198 | task_name: long_term_forecast 199 | dataset: Traffic 200 | data: custom 201 | embed: timeF 202 | root_path: ../dataset/traffic/ 203 | data_path: traffic.csv 204 | features: M 205 | seq_len: 96 206 | label_len: 48 207 | pred_len: 192 208 | enc_in: 862 209 | dec_in: 862 210 | c_out: 862 211 | 212 | LTF_Traffic_p336: 213 | task_name: long_term_forecast 214 | dataset: Traffic 215 | data: custom 216 | embed: timeF 217 | root_path: ../dataset/traffic/ 218 | data_path: traffic.csv 219 | features: M 220 | seq_len: 96 221 | label_len: 48 222 | pred_len: 336 223 | enc_in: 862 224 | dec_in: 862 225 | c_out: 862 226 | 227 | LTF_Traffic_p720: 228 | task_name: long_term_forecast 229 | dataset: Traffic 230 | data: custom 231 | embed: timeF 232 | root_path: ../dataset/traffic/ 233 | data_path: traffic.csv 234 | features: M 235 | seq_len: 96 236 | label_len: 48 237 | pred_len: 720 238 | enc_in: 862 239 | dec_in: 862 240 | c_out: 862 241 | 
242 | LTF_Weather_p96: 243 | task_name: long_term_forecast 244 | dataset: Weather 245 | data: custom 246 | embed: timeF 247 | root_path: ../dataset/weather/ 248 | data_path: weather.csv 249 | features: M 250 | seq_len: 96 251 | label_len: 48 252 | pred_len: 96 253 | enc_in: 21 254 | dec_in: 21 255 | c_out: 21 256 | 257 | LTF_Weather_p192: 258 | task_name: long_term_forecast 259 | dataset: Weather 260 | data: custom 261 | embed: timeF 262 | root_path: ../dataset/weather/ 263 | data_path: weather.csv 264 | features: M 265 | seq_len: 96 266 | label_len: 48 267 | pred_len: 192 268 | enc_in: 21 269 | dec_in: 21 270 | c_out: 21 271 | 272 | LTF_Weather_p336: 273 | task_name: long_term_forecast 274 | dataset: Weather 275 | data: custom 276 | embed: timeF 277 | root_path: ../dataset/weather/ 278 | data_path: weather.csv 279 | features: M 280 | seq_len: 96 281 | label_len: 48 282 | pred_len: 336 283 | enc_in: 21 284 | dec_in: 21 285 | c_out: 21 286 | 287 | LTF_Weather_p720: 288 | task_name: long_term_forecast 289 | dataset: Weather 290 | data: custom 291 | embed: timeF 292 | root_path: ../dataset/weather/ 293 | data_path: weather.csv 294 | features: M 295 | seq_len: 96 296 | label_len: 48 297 | pred_len: 720 298 | enc_in: 21 299 | dec_in: 21 300 | c_out: 21 301 | 302 | CLS_Heartbeat: 303 | task_name: classification 304 | dataset: Heartbeat 305 | data: UEA 306 | embed: timeF 307 | root_path: ../dataset/Heartbeat/ 308 | seq_len: 405 309 | label_len: 0 310 | pred_len: 0 311 | enc_in: 61 312 | num_class: 2 313 | c_out: None 314 | 315 | CLS_JapaneseVowels: 316 | task_name: classification 317 | dataset: JapaneseVowels 318 | data: UEA 319 | embed: timeF 320 | root_path: ../dataset/JapaneseVowels/ 321 | seq_len: 29 322 | label_len: 0 323 | pred_len: 0 324 | enc_in: 12 325 | num_class: 9 326 | c_out: None 327 | 328 | CLS_PEMS-SF: 329 | task_name: classification 330 | dataset: PEMS-SF 331 | data: UEA 332 | embed: timeF 333 | root_path: ../dataset/PEMS-SF/ 334 | seq_len: 144 335 | 
label_len: 0 336 | pred_len: 0 337 | enc_in: 963 338 | num_class: 7 339 | c_out: None 340 | 341 | CLS_SelfRegulationSCP2: 342 | task_name: classification 343 | dataset: SelfRegulationSCP2 344 | data: UEA 345 | embed: timeF 346 | root_path: ../dataset/SelfRegulationSCP2/ 347 | seq_len: 1152 348 | label_len: 0 349 | pred_len: 0 350 | enc_in: 7 351 | num_class: 2 352 | c_out: None 353 | 354 | CLS_SpokenArabicDigits: 355 | task_name: classification 356 | dataset: SpokenArabicDigits 357 | data: UEA 358 | embed: timeF 359 | root_path: ../dataset/SpokenArabicDigits/ 360 | seq_len: 93 361 | label_len: 0 362 | pred_len: 0 363 | enc_in: 13 364 | num_class: 10 365 | c_out: None 366 | 367 | CLS_UWaveGestureLibrary: 368 | task_name: classification 369 | dataset: UWaveGestureLibrary 370 | data: UEA 371 | embed: timeF 372 | root_path: ../dataset/UWaveGestureLibrary/ 373 | seq_len: 315 374 | label_len: 0 375 | pred_len: 0 376 | enc_in: 3 377 | num_class: 8 378 | c_out: None 379 | 380 | CLS_ECG5000: 381 | task_name: classification 382 | dataset: ECG5000 383 | data: UEA 384 | embed: timeF 385 | root_path: ../dataset/UCR/ECG5000 386 | seq_len: 140 387 | label_len: 0 388 | pred_len: 0 389 | enc_in: 1 390 | num_class: 5 391 | c_out: None 392 | 393 | CLS_NonInvasiveFetalECGThorax1: 394 | task_name: classification 395 | dataset: NonInvasiveFetalECGThorax1 396 | data: UEA 397 | embed: timeF 398 | root_path: ../dataset/UCR/NonInvasiveFetalECGThorax1 399 | seq_len: 750 400 | label_len: 0 401 | pred_len: 0 402 | enc_in: 1 403 | num_class: 52 404 | c_out: None 405 | 406 | CLS_Blink: 407 | task_name: classification 408 | dataset: Blink 409 | data: UEA 410 | embed: timeF 411 | root_path: ../dataset/Blink 412 | seq_len: 510 413 | label_len: 0 414 | pred_len: 0 415 | enc_in: 4 416 | num_class: 2 417 | c_out: None 418 | 419 | CLS_FaceDetection: 420 | task_name: classification 421 | dataset: FaceDetection 422 | data: UEA 423 | embed: timeF 424 | root_path: ../dataset/FaceDetection 425 | seq_len: 62 
426 | label_len: 0 427 | pred_len: 0 428 | enc_in: 144 429 | num_class: 2 430 | c_out: None 431 | 432 | CLS_ElectricDevices: 433 | task_name: classification 434 | dataset: ElectricDevices 435 | data: UEA 436 | embed: timeF 437 | root_path: ../dataset/UCR/ElectricDevices 438 | seq_len: 96 439 | label_len: 0 440 | pred_len: 0 441 | enc_in: 1 442 | num_class: 7 443 | c_out: None 444 | 445 | CLS_Trace: 446 | task_name: classification 447 | dataset: Trace 448 | data: UEA 449 | embed: timeF 450 | root_path: ../dataset/UCR/Trace 451 | seq_len: 275 452 | label_len: 0 453 | pred_len: 0 454 | enc_in: 1 455 | num_class: 4 456 | c_out: None 457 | 458 | CLS_FordB: 459 | task_name: classification 460 | dataset: FordB 461 | data: UEA 462 | embed: timeF 463 | root_path: ../dataset/UCR/FordB 464 | seq_len: 500 465 | label_len: 0 466 | pred_len: 0 467 | enc_in: 1 468 | num_class: 2 469 | c_out: None 470 | 471 | CLS_MotionSenseHAR: 472 | task_name: classification 473 | dataset: MotionSenseHAR 474 | data: UEA 475 | embed: timeF 476 | root_path: ../dataset/MotionSenseHAR 477 | seq_len: 200 478 | label_len: 0 479 | pred_len: 0 480 | enc_in: 12 481 | num_class: 6 482 | c_out: None 483 | 484 | CLS_EMOPain: 485 | task_name: classification 486 | dataset: EMOPain 487 | data: UEA 488 | embed: timeF 489 | root_path: ../dataset/EMOPain 490 | seq_len: 180 491 | label_len: 0 492 | pred_len: 0 493 | enc_in: 30 494 | num_class: 3 495 | c_out: None 496 | 497 | CLS_Chinatown: 498 | task_name: classification 499 | dataset: Chinatown 500 | data: UEA 501 | embed: timeF 502 | root_path: ../dataset/UCR/Chinatown 503 | seq_len: 24 504 | label_len: 0 505 | pred_len: 0 506 | enc_in: 1 507 | num_class: 2 508 | c_out: None 509 | 510 | CLS_MelbournePedestrian: 511 | task_name: classification 512 | dataset: MelbournePedestrian 513 | data: UEA 514 | embed: timeF 515 | root_path: ../dataset/UCR/MelbournePedestrian 516 | seq_len: 24 517 | label_len: 0 518 | pred_len: 0 519 | enc_in: 1 520 | num_class: 10 521 | 
c_out: None 522 | 523 | CLS_SharePriceIncrease: 524 | task_name: classification 525 | dataset: SharePriceIncrease 526 | data: UEA 527 | embed: timeF 528 | root_path: ../dataset/SharePriceIncrease 529 | seq_len: 60 530 | label_len: 0 531 | pred_len: 0 532 | enc_in: 1 533 | num_class: 2 534 | c_out: None -------------------------------------------------------------------------------- /data_provider/multi_task_pretrain.yaml: -------------------------------------------------------------------------------- 1 | task_dataset: 2 | NN5_p112: 3 | task_name: pretrain_long_term_forecast 4 | dataset_name: nn5_daily_without_missing 5 | dataset: NN5 6 | data: gluonts 7 | root_path: ../dataset/gluonts 8 | seq_len: 224 9 | label_len: 0 10 | pred_len: 0 11 | features: M 12 | embed: timeF 13 | enc_in: 111 14 | dec_in: 111 15 | c_out: 111 16 | 17 | LTF_ECL_p96: 18 | task_name: pretrain_long_term_forecast 19 | dataset: ECL 20 | data: custom 21 | embed: timeF 22 | root_path: ../dataset/electricity/ 23 | data_path: electricity.csv 24 | features: M 25 | seq_len: 192 26 | label_len: 48 27 | pred_len: 0 28 | enc_in: 321 29 | dec_in: 321 30 | c_out: 321 31 | 32 | LTF_ECL_p192: 33 | task_name: pretrain_long_term_forecast 34 | dataset: ECL 35 | data: custom 36 | embed: timeF 37 | root_path: ../dataset/electricity/ 38 | data_path: electricity.csv 39 | features: M 40 | seq_len: 288 41 | label_len: 48 42 | pred_len: 0 43 | enc_in: 321 44 | dec_in: 321 45 | c_out: 321 46 | 47 | LTF_ECL_p336: 48 | task_name: pretrain_long_term_forecast 49 | dataset: ECL 50 | data: custom 51 | embed: timeF 52 | root_path: ../dataset/electricity/ 53 | data_path: electricity.csv 54 | features: M 55 | seq_len: 432 56 | label_len: 48 57 | pred_len: 0 58 | enc_in: 321 59 | dec_in: 321 60 | c_out: 321 61 | 62 | LTF_ECL_p720: 63 | task_name: pretrain_long_term_forecast 64 | dataset: ECL 65 | data: custom 66 | embed: timeF 67 | root_path: ../dataset/electricity/ 68 | data_path: electricity.csv 69 | features: M 70 | seq_len: 
816 71 | label_len: 48 72 | pred_len: 0 73 | enc_in: 321 74 | dec_in: 321 75 | c_out: 321 76 | 77 | LTF_ETTh1_p96: 78 | task_name: pretrain_long_term_forecast 79 | dataset: ETTh1 80 | data: ETTh1 81 | embed: timeF 82 | root_path: ../dataset/ETT-small/ 83 | data_path: ETTh1.csv 84 | features: M 85 | seq_len: 192 86 | label_len: 48 87 | pred_len: 0 88 | enc_in: 7 89 | dec_in: 7 90 | c_out: 7 91 | 92 | LTF_ETTh1_p192: 93 | task_name: pretrain_long_term_forecast 94 | dataset: ETTh1 95 | data: ETTh1 96 | embed: timeF 97 | root_path: ../dataset/ETT-small/ 98 | data_path: ETTh1.csv 99 | features: M 100 | seq_len: 288 101 | label_len: 48 102 | pred_len: 0 103 | enc_in: 7 104 | dec_in: 7 105 | c_out: 7 106 | 107 | LTF_ETTh1_p336: 108 | task_name: pretrain_long_term_forecast 109 | dataset: ETTh1 110 | data: ETTh1 111 | embed: timeF 112 | root_path: ../dataset/ETT-small/ 113 | data_path: ETTh1.csv 114 | features: M 115 | seq_len: 432 116 | label_len: 48 117 | pred_len: 0 118 | enc_in: 7 119 | dec_in: 7 120 | c_out: 7 121 | 122 | LTF_ETTh1_p720: 123 | task_name: pretrain_long_term_forecast 124 | dataset: ETTh1 125 | data: ETTh1 126 | embed: timeF 127 | root_path: ../dataset/ETT-small/ 128 | data_path: ETTh1.csv 129 | features: M 130 | seq_len: 816 131 | label_len: 48 132 | pred_len: 0 133 | enc_in: 7 134 | dec_in: 7 135 | c_out: 7 136 | 137 | LTF_Exchange_p192: 138 | task_name: pretrain_long_term_forecast 139 | dataset: Exchange 140 | data: custom 141 | embed: timeF 142 | root_path: ../dataset/exchange_rate/ 143 | data_path: exchange_rate.csv 144 | features: M 145 | seq_len: 288 146 | label_len: 48 147 | pred_len: 0 148 | enc_in: 8 149 | dec_in: 8 150 | c_out: 8 151 | 152 | LTF_Exchange_p336: 153 | task_name: pretrain_long_term_forecast 154 | dataset: Exchange 155 | data: custom 156 | embed: timeF 157 | root_path: ../dataset/exchange_rate/ 158 | data_path: exchange_rate.csv 159 | features: M 160 | seq_len: 432 161 | label_len: 48 162 | pred_len: 0 163 | enc_in: 8 164 | dec_in: 
8 165 | c_out: 8 166 | 167 | LTF_ILI_p60: 168 | task_name: pretrain_long_term_forecast 169 | dataset: ILI 170 | data: custom 171 | embed: timeF 172 | root_path: ../dataset/illness/ 173 | data_path: national_illness.csv 174 | features: M 175 | seq_len: 96 176 | label_len: 18 177 | pred_len: 0 178 | enc_in: 7 179 | dec_in: 7 180 | c_out: 7 181 | 182 | LTF_Traffic_p96: 183 | task_name: pretrain_long_term_forecast 184 | dataset: Traffic 185 | data: custom 186 | embed: timeF 187 | root_path: ../dataset/traffic/ 188 | data_path: traffic.csv 189 | features: M 190 | seq_len: 192 191 | label_len: 48 192 | pred_len: 0 193 | enc_in: 862 194 | dec_in: 862 195 | c_out: 862 196 | 197 | LTF_Traffic_p192: 198 | task_name: pretrain_long_term_forecast 199 | dataset: Traffic 200 | data: custom 201 | embed: timeF 202 | root_path: ../dataset/traffic/ 203 | data_path: traffic.csv 204 | features: M 205 | seq_len: 288 206 | label_len: 48 207 | pred_len: 0 208 | enc_in: 862 209 | dec_in: 862 210 | c_out: 862 211 | 212 | LTF_Traffic_p336: 213 | task_name: pretrain_long_term_forecast 214 | dataset: Traffic 215 | data: custom 216 | embed: timeF 217 | root_path: ../dataset/traffic/ 218 | data_path: traffic.csv 219 | features: M 220 | seq_len: 432 221 | label_len: 48 222 | pred_len: 0 223 | enc_in: 862 224 | dec_in: 862 225 | c_out: 862 226 | 227 | LTF_Traffic_p720: 228 | task_name: pretrain_long_term_forecast 229 | dataset: Traffic 230 | data: custom 231 | embed: timeF 232 | root_path: ../dataset/traffic/ 233 | data_path: traffic.csv 234 | features: M 235 | seq_len: 816 236 | label_len: 48 237 | pred_len: 0 238 | enc_in: 862 239 | dec_in: 862 240 | c_out: 862 241 | 242 | LTF_Weather_p96: 243 | task_name: pretrain_long_term_forecast 244 | dataset: Weather 245 | data: custom 246 | embed: timeF 247 | root_path: ../dataset/weather/ 248 | data_path: weather.csv 249 | features: M 250 | seq_len: 192 251 | label_len: 48 252 | pred_len: 0 253 | enc_in: 21 254 | dec_in: 21 255 | c_out: 21 256 | 257 | 
LTF_Weather_p192: 258 | task_name: pretrain_long_term_forecast 259 | dataset: Weather 260 | data: custom 261 | embed: timeF 262 | root_path: ../dataset/weather/ 263 | data_path: weather.csv 264 | features: M 265 | seq_len: 288 266 | label_len: 48 267 | pred_len: 0 268 | enc_in: 21 269 | dec_in: 21 270 | c_out: 21 271 | 272 | LTF_Weather_p336: 273 | task_name: pretrain_long_term_forecast 274 | dataset: Weather 275 | data: custom 276 | embed: timeF 277 | root_path: ../dataset/weather/ 278 | data_path: weather.csv 279 | features: M 280 | seq_len: 432 281 | label_len: 48 282 | pred_len: 0 283 | enc_in: 21 284 | dec_in: 21 285 | c_out: 21 286 | 287 | LTF_Weather_p720: 288 | task_name: pretrain_long_term_forecast 289 | dataset: Weather 290 | data: custom 291 | embed: timeF 292 | root_path: ../dataset/weather/ 293 | data_path: weather.csv 294 | features: M 295 | seq_len: 816 296 | label_len: 48 297 | pred_len: 0 298 | enc_in: 21 299 | dec_in: 21 300 | c_out: 21 301 | 302 | CLS_Heartbeat: 303 | task_name: pretrain_classification 304 | dataset: Heartbeat 305 | data: UEA 306 | embed: timeF 307 | root_path: ../dataset/Heartbeat/ 308 | seq_len: 405 309 | label_len: 0 310 | pred_len: 0 311 | enc_in: 61 312 | num_class: 2 313 | c_out: None 314 | 315 | CLS_JapaneseVowels: 316 | task_name: pretrain_classification 317 | dataset: JapaneseVowels 318 | data: UEA 319 | embed: timeF 320 | root_path: ../dataset/JapaneseVowels/ 321 | seq_len: 29 322 | label_len: 0 323 | pred_len: 0 324 | enc_in: 12 325 | num_class: 9 326 | c_out: None 327 | 328 | CLS_PEMS-SF: 329 | task_name: pretrain_classification 330 | dataset: PEMS-SF 331 | data: UEA 332 | embed: timeF 333 | root_path: ../dataset/PEMS-SF/ 334 | seq_len: 144 335 | label_len: 0 336 | pred_len: 0 337 | enc_in: 963 338 | num_class: 7 339 | c_out: None 340 | 341 | CLS_SelfRegulationSCP2: 342 | task_name: pretrain_classification 343 | dataset: SelfRegulationSCP2 344 | data: UEA 345 | embed: timeF 346 | root_path: 
../dataset/SelfRegulationSCP2/ 347 | seq_len: 1152 348 | label_len: 0 349 | pred_len: 0 350 | enc_in: 7 351 | num_class: 2 352 | c_out: None 353 | 354 | CLS_SpokenArabicDigits: 355 | task_name: pretrain_classification 356 | dataset: SpokenArabicDigits 357 | data: UEA 358 | embed: timeF 359 | root_path: ../dataset/SpokenArabicDigits/ 360 | seq_len: 93 361 | label_len: 0 362 | pred_len: 0 363 | enc_in: 13 364 | num_class: 10 365 | c_out: None 366 | 367 | CLS_UWaveGestureLibrary: 368 | task_name: pretrain_classification 369 | dataset: UWaveGestureLibrary 370 | data: UEA 371 | embed: timeF 372 | root_path: ../dataset/UWaveGestureLibrary/ 373 | seq_len: 315 374 | label_len: 0 375 | pred_len: 0 376 | enc_in: 3 377 | num_class: 8 378 | c_out: None 379 | 380 | CLS_ECG5000: 381 | task_name: pretrain_classification 382 | dataset: ECG5000 383 | data: UEA 384 | embed: timeF 385 | root_path: ../dataset/UCR/ECG5000 386 | seq_len: 140 387 | label_len: 0 388 | pred_len: 0 389 | enc_in: 1 390 | num_class: 5 391 | c_out: None 392 | 393 | CLS_NonInvasiveFetalECGThorax1: 394 | task_name: pretrain_classification 395 | dataset: NonInvasiveFetalECGThorax1 396 | data: UEA 397 | embed: timeF 398 | root_path: ../dataset/UCR/NonInvasiveFetalECGThorax1 399 | seq_len: 750 400 | label_len: 0 401 | pred_len: 0 402 | enc_in: 1 403 | num_class: 52 404 | c_out: None 405 | 406 | CLS_Blink: 407 | task_name: pretrain_classification 408 | dataset: Blink 409 | data: UEA 410 | embed: timeF 411 | root_path: ../dataset/Blink 412 | seq_len: 510 413 | label_len: 0 414 | pred_len: 0 415 | enc_in: 4 416 | num_class: 2 417 | c_out: None 418 | 419 | CLS_FaceDetection: 420 | task_name: pretrain_classification 421 | dataset: FaceDetection 422 | data: UEA 423 | embed: timeF 424 | root_path: ../dataset/FaceDetection 425 | seq_len: 62 426 | label_len: 0 427 | pred_len: 0 428 | enc_in: 144 429 | num_class: 2 430 | c_out: None 431 | 432 | CLS_ElectricDevices: 433 | task_name: pretrain_classification 434 | dataset: 
ElectricDevices 435 | data: UEA 436 | embed: timeF 437 | root_path: ../dataset/UCR/ElectricDevices 438 | seq_len: 96 439 | label_len: 0 440 | pred_len: 0 441 | enc_in: 1 442 | num_class: 7 443 | c_out: None 444 | 445 | CLS_Trace: 446 | task_name: pretrain_classification 447 | dataset: Trace 448 | data: UEA 449 | embed: timeF 450 | root_path: ../dataset/UCR/Trace 451 | seq_len: 275 452 | label_len: 0 453 | pred_len: 0 454 | enc_in: 1 455 | num_class: 4 456 | c_out: None 457 | 458 | CLS_FordB: 459 | task_name: pretrain_classification 460 | dataset: FordB 461 | data: UEA 462 | embed: timeF 463 | root_path: ../dataset/UCR/FordB 464 | seq_len: 500 465 | label_len: 0 466 | pred_len: 0 467 | enc_in: 1 468 | num_class: 2 469 | c_out: None 470 | 471 | CLS_MotionSenseHAR: 472 | task_name: pretrain_classification 473 | dataset: MotionSenseHAR 474 | data: UEA 475 | embed: timeF 476 | root_path: ../dataset/MotionSenseHAR 477 | seq_len: 200 478 | label_len: 0 479 | pred_len: 0 480 | enc_in: 12 481 | num_class: 6 482 | c_out: None 483 | 484 | CLS_EMOPain: 485 | task_name: pretrain_classification 486 | dataset: EMOPain 487 | data: UEA 488 | embed: timeF 489 | root_path: ../dataset/EMOPain 490 | seq_len: 180 491 | label_len: 0 492 | pred_len: 0 493 | enc_in: 30 494 | num_class: 3 495 | c_out: None 496 | 497 | CLS_Chinatown: 498 | task_name: pretrain_classification 499 | dataset: Chinatown 500 | data: UEA 501 | embed: timeF 502 | root_path: ../dataset/UCR/Chinatown 503 | seq_len: 24 504 | label_len: 0 505 | pred_len: 0 506 | enc_in: 1 507 | num_class: 2 508 | c_out: None 509 | 510 | CLS_MelbournePedestrian: 511 | task_name: pretrain_classification 512 | dataset: MelbournePedestrian 513 | data: UEA 514 | embed: timeF 515 | root_path: ../dataset/UCR/MelbournePedestrian 516 | seq_len: 24 517 | label_len: 0 518 | pred_len: 0 519 | enc_in: 1 520 | num_class: 10 521 | c_out: None 522 | 523 | CLS_SharePriceIncrease: 524 | task_name: pretrain_classification 525 | dataset: SharePriceIncrease 
526 | data: UEA 527 | embed: timeF 528 | root_path: ../dataset/SharePriceIncrease 529 | seq_len: 60 530 | label_len: 0 531 | pred_len: 0 532 | enc_in: 1 533 | num_class: 2 534 | c_out: None -------------------------------------------------------------------------------- /data_provider/multitask_zero_shot_new_length.yaml: -------------------------------------------------------------------------------- 1 | task_dataset: 2 | LTF_ECL_p96: 3 | task_name: long_term_forecast 4 | dataset: ECL 5 | data: custom 6 | embed: timeF 7 | root_path: ../dataset/electricity/ 8 | data_path: electricity.csv 9 | features: M 10 | seq_len: 96 11 | label_len: 48 12 | pred_len: 96 13 | enc_in: 321 14 | dec_in: 321 15 | c_out: 321 16 | 17 | LTF_ECL_p192: 18 | task_name: long_term_forecast 19 | dataset: ECL 20 | data: custom 21 | embed: timeF 22 | root_path: ../dataset/electricity/ 23 | data_path: electricity.csv 24 | features: M 25 | seq_len: 96 26 | label_len: 48 27 | pred_len: 192 28 | enc_in: 321 29 | dec_in: 321 30 | c_out: 321 31 | 32 | LTF_ECL_p336: 33 | task_name: long_term_forecast 34 | dataset: ECL 35 | data: custom 36 | embed: timeF 37 | root_path: ../dataset/electricity/ 38 | data_path: electricity.csv 39 | features: M 40 | seq_len: 96 41 | label_len: 48 42 | pred_len: 336 43 | enc_in: 321 44 | dec_in: 321 45 | c_out: 321 46 | 47 | LTF_ETTh1_p96: 48 | task_name: long_term_forecast 49 | dataset: ETTh1 50 | data: ETTh1 51 | embed: timeF 52 | root_path: ../dataset/ETT-small/ 53 | data_path: ETTh1.csv 54 | features: M 55 | seq_len: 96 56 | label_len: 48 57 | pred_len: 96 58 | enc_in: 7 59 | dec_in: 7 60 | c_out: 7 61 | 62 | LTF_ETTh1_p192: 63 | task_name: long_term_forecast 64 | dataset: ETTh1 65 | data: ETTh1 66 | embed: timeF 67 | root_path: ../dataset/ETT-small/ 68 | data_path: ETTh1.csv 69 | features: M 70 | seq_len: 96 71 | label_len: 48 72 | pred_len: 192 73 | enc_in: 7 74 | dec_in: 7 75 | c_out: 7 76 | 77 | LTF_ETTh1_p336: 78 | task_name: long_term_forecast 79 | dataset: ETTh1 
80 | data: ETTh1 81 | embed: timeF 82 | root_path: ../dataset/ETT-small/ 83 | data_path: ETTh1.csv 84 | features: M 85 | seq_len: 96 86 | label_len: 48 87 | pred_len: 336 88 | enc_in: 7 89 | dec_in: 7 90 | c_out: 7 91 | 92 | LTF_Exchange_p192: 93 | task_name: long_term_forecast 94 | dataset: Exchange 95 | data: custom 96 | embed: timeF 97 | root_path: ../dataset/exchange_rate/ 98 | data_path: exchange_rate.csv 99 | features: M 100 | seq_len: 96 101 | label_len: 48 102 | pred_len: 192 103 | enc_in: 8 104 | dec_in: 8 105 | c_out: 8 106 | 107 | LTF_Exchange_p336: 108 | task_name: long_term_forecast 109 | dataset: Exchange 110 | data: custom 111 | embed: timeF 112 | root_path: ../dataset/exchange_rate/ 113 | data_path: exchange_rate.csv 114 | features: M 115 | seq_len: 96 116 | label_len: 48 117 | pred_len: 336 118 | enc_in: 8 119 | dec_in: 8 120 | c_out: 8 121 | 122 | LTF_Traffic_p96: 123 | task_name: long_term_forecast 124 | dataset: Traffic 125 | data: custom 126 | embed: timeF 127 | root_path: ../dataset/traffic/ 128 | data_path: traffic.csv 129 | features: M 130 | seq_len: 96 131 | label_len: 48 132 | pred_len: 96 133 | enc_in: 862 134 | dec_in: 862 135 | c_out: 862 136 | 137 | LTF_Traffic_p192: 138 | task_name: long_term_forecast 139 | dataset: Traffic 140 | data: custom 141 | embed: timeF 142 | root_path: ../dataset/traffic/ 143 | data_path: traffic.csv 144 | features: M 145 | seq_len: 96 146 | label_len: 48 147 | pred_len: 192 148 | enc_in: 862 149 | dec_in: 862 150 | c_out: 862 151 | 152 | LTF_Traffic_p336: 153 | task_name: long_term_forecast 154 | dataset: Traffic 155 | data: custom 156 | embed: timeF 157 | root_path: ../dataset/traffic/ 158 | data_path: traffic.csv 159 | features: M 160 | seq_len: 96 161 | label_len: 48 162 | pred_len: 336 163 | enc_in: 862 164 | dec_in: 862 165 | c_out: 862 166 | 167 | LTF_Weather_p96: 168 | task_name: long_term_forecast 169 | dataset: Weather 170 | data: custom 171 | embed: timeF 172 | root_path: ../dataset/weather/ 173 | 
import os
import numpy as np
import pandas as pd
import torch


def collate_fn(data, max_len=None):
    """Build padded mini-batch tensors from a list of (X, y) samples.

    Args:
        data: len(batch_size) list of tuples (X, y).
            - X: torch tensor of shape (seq_length, feat_dim); variable seq_length.
            - y: torch tensor of labels; all labels must share one shape so
              that they can be stacked along a new batch dimension.
        max_len: global fixed sequence length. Used for architectures requiring fixed length input,
            where the batch length cannot vary dynamically. Longer sequences are clipped,
            shorter are padded with 0s. Defaults to the longest sequence in the batch.
    Returns:
        X: (batch_size, padded_length, feat_dim) torch tensor of zero-padded features
        targets: (batch_size, ...) torch tensor of stacked labels
        padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 means padding
    """
    batch_size = len(data)
    features, labels = zip(*data)

    # Original (unpadded) sequence length of every sample in the batch.
    lengths = [X.shape[0] for X in features]
    if max_len is None:
        max_len = max(lengths)

    # Stack and pad features (convert 2D to 3D tensors, i.e. add batch dimension).
    X = torch.zeros(batch_size, max_len, features[0].shape[-1])
    for i, (sample, length) in enumerate(zip(features, lengths)):
        end = min(length, max_len)
        X[i, :end, :] = sample[:end, :]

    targets = torch.stack(labels, dim=0)

    # int64 instead of int16: sequences longer than 32767 steps would
    # otherwise overflow and yield a corrupted mask.
    padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.long),
                                 max_len=max_len)

    return X, targets, padding_masks


def padding_mask(lengths, max_len=None):
    """
    Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths,
    where 1 means keep element at this position (time step).

    Args:
        lengths: 1D tensor holding the sequence length of every sample.
        max_len: width of the mask; when omitted (or 0) the largest value in
            `lengths` is used.
    Returns:
        (batch_size, max_len) boolean tensor on the same device as `lengths`.
    """
    if not max_len:
        # Fixed: the original called the non-existent `lengths.max_val()`,
        # which raised AttributeError whenever max_len was not supplied.
        max_len = int(lengths.max())
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    # Broadcasting (1, max_len) against (batch_size, 1) gives the full mask.
    return positions.unsqueeze(0).lt(lengths.unsqueeze(1))


class Normalizer(object):
    """
    Normalizes dataframe across ALL contained rows (time steps). Different from per-sample normalization.
    """

    def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None):
        """
        Args:
            norm_type: choose from:
                "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps)
                "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. across only its own rows)
            mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values
        """
        self.norm_type = norm_type
        self.mean = mean
        self.std = std
        self.min_val = min_val
        self.max_val = max_val

    def normalize(self, df):
        """
        Args:
            df: input dataframe
        Returns:
            df: normalized dataframe
        Raises:
            NameError: if `norm_type` is not one of the supported methods.
        """
        # Added to every denominator so that constant features do not divide by zero.
        eps = np.finfo(float).eps

        if self.norm_type == "standardization":
            # Statistics are computed once and cached on the instance, so the
            # same values are reused on later calls.
            if self.mean is None:
                self.mean = df.mean()
                self.std = df.std()
            return (df - self.mean) / (self.std + eps)

        if self.norm_type == "minmax":
            if self.max_val is None:
                self.max_val = df.max()
                self.min_val = df.min()
            return (df - self.min_val) / (self.max_val - self.min_val + eps)

        if self.norm_type == "per_sample_std":
            # Rows sharing an index value are treated as one sample.
            grouped = df.groupby(by=df.index)
            # Fixed: eps added for consistency with the other branches; a
            # constant sample previously produced 0/0 = NaN.
            return (df - grouped.transform('mean')) / (grouped.transform('std') + eps)

        if self.norm_type == "per_sample_minmax":
            grouped = df.groupby(by=df.index)
            min_vals = grouped.transform('min')
            return (df - min_vals) / (grouped.transform('max') - min_vals + eps)

        raise NameError(f'Normalize method "{self.norm_type}" not implemented')


def interpolate_missing(y):
    """
    Replaces NaN values in pd.Series `y` using linear interpolation.
    Leading and trailing NaNs are filled as well (limit_direction='both').
    """
    if y.isna().any():
        y = y.interpolate(method='linear', limit_direction='both')
    return y


def subsample(y, limit=256, factor=2):
    """
    If a given Series is longer than `limit`, returns subsampled sequence by the specified integer factor
    (every `factor`-th element, with the index reset); otherwise returns `y` unchanged.
    """
    if len(y) > limit:
        return y[::factor].reset_index(drop=True)
    return y
command -v gdown &> /dev/null 7 | then 8 | echo "installing gdown, for downloading from google drive" 9 | pip install gdown 10 | fi 11 | 12 | # TimesNet data 13 | # downloads all_datasets.zip and extracts into dataset/ 14 | if [ ! -f dataset/all_datasets.zip ]; then 15 | gdown "https://drive.google.com/file/d/1pmXvqWsfUeXWCMz5fqsP8WLKXR5jxY8z/view?usp=drive_link" --fuzzy -O dataset/all_datasets.zip 16 | unzip dataset/all_datasets.zip -d dataset/ 17 | mv dataset/all_datasets/* dataset/ 18 | rm -rf dataset/all_datasets 19 | fi 20 | 21 | # UAE data 22 | # downloads Multivariate2018_ts.zip then extacts into dataset/UAE/ 23 | if [ ! -f dataset/Multivariate2018_ts.zip ]; then 24 | wget "https://www.timeseriesclassification.com/aeon-toolkit/Archives/Multivariate2018_ts.zip" -O dataset/Multivariate2018_ts.zip 25 | unzip dataset/Multivariate2018_ts.zip -d dataset/ 26 | mv dataset/Multivariate_ts dataset/UAE 27 | fi 28 | 29 | # UCR data 30 | # downloads Univariate2018_ts.zip then extacts into dataset/UCR/ 31 | if [ ! -f dataset/Univariate2018_ts.zip ]; then 32 | wget "https://www.timeseriesclassification.com/aeon-toolkit/Archives/Univariate2018_ts.zip" -O dataset/Univariate2018_ts.zip 33 | unzip dataset/Univariate2018_ts.zip -d dataset/ 34 | mv dataset/Univariate_ts dataset/UCR 35 | fi 36 | 37 | 38 | # Other timeseriesclassification.com datasets: 39 | 40 | # Blink data 41 | if [ ! -f dataset/Blink.zip ]; then 42 | wget "https://www.timeseriesclassification.com/aeon-toolkit/Blink.zip" -O dataset/Blink.zip 43 | unzip dataset/Blink.zip -d dataset/Blink 44 | fi 45 | 46 | # MotionSenseHAR data 47 | if [ ! -f dataset/MotionSenseHAR.zip ]; then 48 | wget "https://www.timeseriesclassification.com/aeon-toolkit/MotionSenseHAR.zip" -O dataset/MotionSenseHAR.zip 49 | unzip dataset/MotionSenseHAR.zip -d dataset/MotionSenseHAR 50 | fi 51 | 52 | # EMOPain data 53 | if [ ! 
-f dataset/EMOPain.zip ]; then 54 | wget "https://www.timeseriesclassification.com/aeon-toolkit/EMOPain.zip" -O dataset/EMOPain.zip 55 | unzip dataset/EMOPain.zip -d dataset/EMOPain 56 | fi 57 | 58 | 59 | # SharePriceIncreasen data 60 | if [ ! -f dataset/SharePriceIncrease.zip ]; then 61 | wget "https://www.timeseriesclassification.com/aeon-toolkit/SharePriceIncrease.zip" -O dataset/SharePriceIncrease.zip 62 | unzip dataset/SharePriceIncrease.zip -d dataset/SharePriceIncrease 63 | fi 64 | 65 | 66 | # AbnormalHeartbeat data 67 | if [ ! -f dataset/AbnormalHeartbeat.zip ]; then 68 | wget "https://www.timeseriesclassification.com/aeon-toolkit/AbnormalHeartbeat.zip" -O dataset/AbnormalHeartbeat.zip 69 | unzip dataset/AbnormalHeartbeat.zip -d dataset/AbnormalHeartbeat 70 | fi 71 | 72 | # AsphaltObstacles data 73 | if [ ! -f dataset/AsphaltObstacles.zip ]; then 74 | wget "https://www.timeseriesclassification.com/aeon-toolkit/AsphaltObstacles.zip" -O dataset/AsphaltObstacles.zip 75 | unzip dataset/AsphaltObstacles.zip -d dataset/AsphaltObstacles 76 | fi 77 | 78 | # AsphaltObstaclesCoordinates data 79 | if [ ! -f dataset/AsphaltObstaclesCoordinates.zip ]; then 80 | wget "https://www.timeseriesclassification.com/aeon-toolkit/AsphaltObstaclesCoordinates.zip" -O dataset/AsphaltObstaclesCoordinates.zip 81 | unzip dataset/AsphaltObstaclesCoordinates.zip -d dataset/AsphaltObstaclesCoordinates 82 | fi 83 | 84 | # AsphaltPavementType data 85 | if [ ! -f dataset/AsphaltPavementType.zip ]; then 86 | wget "https://www.timeseriesclassification.com/aeon-toolkit/AsphaltPavementType.zip" -O dataset/AsphaltPavementType.zip 87 | unzip dataset/AsphaltPavementType.zip -d dataset/AsphaltPavementType 88 | fi 89 | 90 | # AsphaltPavementTypeCoordinates data 91 | if [ ! 
-f dataset/AsphaltPavementTypeCoordinates.zip ]; then 92 | wget "https://www.timeseriesclassification.com/aeon-toolkit/AsphaltPavementTypeCoordinates.zip" -O dataset/AsphaltPavementTypeCoordinates.zip 93 | unzip dataset/AsphaltPavementTypeCoordinates.zip -d dataset/AsphaltPavementTypeCoordinates 94 | fi 95 | 96 | # AsphaltRegularity data 97 | if [ ! -f dataset/AsphaltRegularity.zip ]; then 98 | wget "https://www.timeseriesclassification.com/aeon-toolkit/AsphaltRegularity.zip" -O dataset/AsphaltRegularity.zip 99 | unzip dataset/AsphaltRegularity.zip -d dataset/AsphaltRegularity 100 | fi 101 | 102 | # AsphaltRegularityCoordinates data 103 | if [ ! -f dataset/AsphaltRegularityCoordinates.zip ]; then 104 | wget "https://www.timeseriesclassification.com/aeon-toolkit/AsphaltRegularityCoordinates.zip" -O dataset/AsphaltRegularityCoordinates.zip 105 | unzip dataset/AsphaltRegularityCoordinates.zip -d dataset/AsphaltRegularityCoordinates 106 | fi 107 | 108 | # BinaryHeartbeat data 109 | if [ ! -f dataset/BinaryHeartbeat.zip ]; then 110 | wget "https://www.timeseriesclassification.com/aeon-toolkit/BinaryHeartbeat.zip" -O dataset/BinaryHeartbeat.zip 111 | unzip dataset/BinaryHeartbeat.zip -d dataset/BinaryHeartbeat 112 | fi 113 | 114 | # CatsDogs data 115 | if [ ! -f dataset/CatsDogs.zip ]; then 116 | wget "https://www.timeseriesclassification.com/aeon-toolkit/CatsDogs.zip" -O dataset/CatsDogs.zip 117 | unzip dataset/CatsDogs.zip -d dataset/CatsDogs 118 | fi 119 | 120 | # Colposcopy data 121 | if [ ! -f dataset/Colposcopy.zip ]; then 122 | wget "https://www.timeseriesclassification.com/aeon-toolkit/Colposcopy.zip" -O dataset/Colposcopy.zip 123 | unzip dataset/Colposcopy.zip -d dataset/Colposcopy 124 | fi 125 | 126 | # CounterMovementJump data 127 | if [ ! 
-f dataset/CounterMovementJump.zip ]; then 128 | wget "https://www.timeseriesclassification.com/aeon-toolkit/CounterMovementJump.zip" -O dataset/CounterMovementJump.zip 129 | unzip dataset/CounterMovementJump.zip -d dataset/CounterMovementJump 130 | fi 131 | 132 | # DucksAndGeese data 133 | if [ ! -f dataset/DucksAndGeese.zip ]; then 134 | wget "https://www.timeseriesclassification.com/aeon-toolkit/DucksAndGeese.zip" -O dataset/DucksAndGeese.zip 135 | unzip dataset/DucksAndGeese.zip -d dataset/DucksAndGeese 136 | fi 137 | 138 | # ElectricDeviceDetection data 139 | if [ ! -f dataset/ElectricDeviceDetection.zip ]; then 140 | wget "https://www.timeseriesclassification.com/aeon-toolkit/ElectricDeviceDetection.zip" -O dataset/ElectricDeviceDetection.zip 141 | unzip dataset/ElectricDeviceDetection.zip -d dataset/ElectricDeviceDetection 142 | fi 143 | 144 | # Epilepsy2 data 145 | if [ ! -f dataset/Epilepsy2.zip ]; then 146 | wget "https://www.timeseriesclassification.com/aeon-toolkit/Epilepsy2.zip" -O dataset/Epilepsy2.zip 147 | unzip dataset/Epilepsy2.zip -d dataset/Epilepsy2 148 | fi 149 | 150 | # EyesOpenShut data 151 | if [ ! -f dataset/EyesOpenShut.zip ]; then 152 | wget "https://www.timeseriesclassification.com/aeon-toolkit/EyesOpenShut.zip" -O dataset/EyesOpenShut.zip 153 | unzip dataset/EyesOpenShut.zip -d dataset/EyesOpenShut 154 | fi 155 | 156 | # FaultDetectionA data 157 | if [ ! -f dataset/FaultDetectionA.zip ]; then 158 | wget "https://www.timeseriesclassification.com/aeon-toolkit/FaultDetectionA.zip" -O dataset/FaultDetectionA.zip 159 | unzip dataset/FaultDetectionA.zip -d dataset/FaultDetectionA 160 | fi 161 | 162 | # FruitFlies data 163 | if [ ! -f dataset/FruitFlies.zip ]; then 164 | wget "https://www.timeseriesclassification.com/aeon-toolkit/FruitFlies.zip" -O dataset/FruitFlies.zip 165 | unzip dataset/FruitFlies.zip -d dataset/FruitFlies 166 | fi 167 | 168 | # InsectSound data 169 | if [ ! 
-f dataset/InsectSound.zip ]; then 170 | wget "https://www.timeseriesclassification.com/aeon-toolkit/InsectSound.zip" -O dataset/InsectSound.zip 171 | unzip dataset/InsectSound.zip -d dataset/InsectSound 172 | fi 173 | 174 | # KeplerLightCurves data 175 | if [ ! -f dataset/KeplerLightCurves.zip ]; then 176 | wget "https://www.timeseriesclassification.com/aeon-toolkit/KeplerLightCurves.zip" -O dataset/KeplerLightCurves.zip 177 | unzip dataset/KeplerLightCurves.zip -d dataset/KeplerLightCurves 178 | fi 179 | 180 | # MindReading data 181 | if [ ! -f dataset/MindReading.zip ]; then 182 | wget "https://www.timeseriesclassification.com/aeon-toolkit/MindReading.zip" -O dataset/MindReading.zip 183 | unzip dataset/MindReading.zip -d dataset/MindReading 184 | fi 185 | 186 | # MosquitoSound data 187 | if [ ! -f dataset/MosquitoSound.zip ]; then 188 | wget "https://www.timeseriesclassification.com/aeon-toolkit/MosquitoSound.zip" -O dataset/MosquitoSound.zip 189 | unzip dataset/MosquitoSound.zip -d dataset/MosquitoSound 190 | fi 191 | 192 | # NerveDamage data 193 | if [ ! -f dataset/NerveDamage.zip ]; then 194 | wget "https://www.timeseriesclassification.com/aeon-toolkit/NerveDamage.zip" -O dataset/NerveDamage.zip 195 | unzip dataset/NerveDamage.zip -d dataset/NerveDamage 196 | fi 197 | 198 | # RightWhaleCalls data 199 | if [ ! -f dataset/RightWhaleCalls.zip ]; then 200 | wget "https://www.timeseriesclassification.com/aeon-toolkit/RightWhaleCalls.zip" -O dataset/RightWhaleCalls.zip 201 | unzip dataset/RightWhaleCalls.zip -d dataset/RightWhaleCalls 202 | fi 203 | 204 | # Sleep data 205 | if [ ! -f dataset/Sleep.zip ]; then 206 | wget "https://www.timeseriesclassification.com/aeon-toolkit/Sleep.zip" -O dataset/Sleep.zip 207 | unzip dataset/Sleep.zip -d dataset/Sleep 208 | fi 209 | 210 | # Tiselac data 211 | if [ ! 
-f dataset/Tiselac.zip ]; then 212 | wget "https://www.timeseriesclassification.com/aeon-toolkit/Tiselac.zip" -O dataset/Tiselac.zip 213 | unzip dataset/Tiselac.zip -d dataset/Tiselac 214 | fi 215 | 216 | # UrbanSound data 217 | if [ ! -f dataset/UrbanSound.zip ]; then 218 | wget "https://www.timeseriesclassification.com/aeon-toolkit/UrbanSound.zip" -O dataset/UrbanSound.zip 219 | unzip dataset/UrbanSound.zip -d dataset/UrbanSound 220 | fi 221 | 222 | # WalkingSittingStanding data 223 | if [ ! -f dataset/WalkingSittingStanding.zip ]; then 224 | wget "https://www.timeseriesclassification.com/aeon-toolkit/WalkingSittingStanding.zip" -O dataset/WalkingSittingStanding.zip 225 | unzip dataset/WalkingSittingStanding.zip -d dataset/WalkingSittingStanding 226 | fi 227 | -------------------------------------------------------------------------------- /exp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/UniTS/0e0281482864017cac8832b2651906ff5375a34e/exp/__init__.py -------------------------------------------------------------------------------- /exp/exp_pretrain.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_factory import data_provider 2 | from utils.tools import cosine_scheduler 3 | from utils.tools import NativeScalerWithGradNormCount as NativeScaler 4 | from utils.losses import UnifiedMaskRecLoss 5 | from utils.dataloader import BalancedDataLoaderIterator 6 | from utils.ddp import is_main_process, get_world_size 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import optim 11 | import torch.distributed as dist 12 | 13 | import os 14 | import time 15 | import warnings 16 | import numpy as np 17 | import yaml 18 | import wandb 19 | import importlib 20 | import sys 21 | 22 | warnings.filterwarnings('ignore') 23 | 24 | def custom_print_decorator(func): 25 | def wrapper(*args, **kwargs): 26 | text = ' 
def get_task_data_config_list(task_data_config, default_batch_size=None):
    """
    Turn the task mapping into a list of [task_name, task_config] pairs.

    Every config gets its 'max_batch' entry set to `default_batch_size`. The
    config dicts are updated in place and the very same objects are placed in
    the returned list, in the mapping's iteration order.
    """
    pairs = []
    for name in task_data_config:
        cfg = task_data_config[name]
        cfg['max_batch'] = default_batch_size
        pairs.append([name, cfg])
    return pairs
self.args, self.task_data_config_list, pretrain=True).to(self.device_id) 87 | if ddp: 88 | model = nn.parallel.DistributedDataParallel( 89 | model, device_ids=[self.device_id], find_unused_parameters=True) 90 | return model.to(self.device_id) 91 | 92 | def _get_data(self, flag): 93 | data_set_list = [] 94 | data_loader_list = [] 95 | for task_data_name, task_config in self.task_data_config.items(): 96 | print("loading dataset:", task_data_name, folder=self.path) 97 | if task_config['data'] == 'UEA' and flag == 'val': 98 | # TODO strange that no val set is used for classification. Set to test set for val 99 | flag = 'test' 100 | data_set, data_loader = data_provider( 101 | self.args, task_config, flag, ddp=True) 102 | data_set_list.append(data_set) 103 | data_loader_list.append(data_loader) 104 | return data_set_list, data_loader_list 105 | 106 | def _select_optimizer(self): 107 | eff_batch_size = self.args.batch_size * self.args.acc_it * get_world_size() 108 | real_learning_rate = self.args.learning_rate * eff_batch_size / 32 109 | print("base lr: %.2e" % (self.args.learning_rate * 32 / eff_batch_size)) 110 | print("actual lr: %.2e" % real_learning_rate) 111 | self.real_learning_rate = real_learning_rate 112 | 113 | print("accumulate grad iterations: %d" % self.args.acc_it) 114 | print("effective batch size: %d" % eff_batch_size) 115 | model_optim = optim.Adam(self.model.parameters( 116 | ), lr=real_learning_rate, betas=(0.9, self.args.beta2), weight_decay=self.args.weight_decay, eps=self.args.eps) 117 | return model_optim 118 | 119 | def train(self, setting): 120 | path = os.path.join(self.args.checkpoints, setting) 121 | if not os.path.exists(path) and is_main_process(): 122 | os.makedirs(path) 123 | self.path = path 124 | 125 | torch.cuda.synchronize() 126 | dist.barrier() 127 | 128 | # Data loader 129 | _, train_loader_list = self._get_data(flag='train') 130 | data_loader_cycle, train_steps = init_and_merge_datasets( 131 | train_loader_list) 132 | 133 | # Set 
up batch size for each task 134 | if self.args.memory_check: 135 | self.memory_check(data_loader_cycle) 136 | torch.cuda.empty_cache() 137 | 138 | torch.cuda.synchronize() 139 | dist.barrier() 140 | 141 | # Model 142 | self.model = self._build_model() 143 | 144 | pytorch_total_params = sum(p.numel() for p in self.model.parameters()) 145 | print("Parameters number {} M".format( 146 | pytorch_total_params/1e6), folder=self.path) 147 | print("{} steps for each epoch".format(train_steps), folder=self.path) 148 | 149 | # Optimizer 150 | model_optim = self._select_optimizer() 151 | lr_schedule = cosine_scheduler( 152 | self.real_learning_rate, 153 | self.args.min_lr, 154 | self.args.train_epochs, train_steps, 155 | warmup_epochs=self.args.warmup_epochs, 156 | ) 157 | 158 | # Loss 159 | criterion = UnifiedMaskRecLoss().to(self.device_id) 160 | scaler = NativeScaler() 161 | 162 | for epoch in range(self.args.train_epochs): 163 | train_loss = self.train_one_epoch( 164 | model_optim, data_loader_cycle, criterion, epoch, train_steps, scaler, lr_schedule) 165 | 166 | print("Epoch: {0}, Steps: {1} | Avg Train Loss: {2:.7f}".format( 167 | epoch + 1, train_steps, train_loss), folder=self.path) 168 | if is_main_process(): 169 | wandb.log({'train_loss_avg': train_loss}) 170 | 171 | if is_main_process(): 172 | save_dict = { 173 | 'student': self.model.state_dict(), 174 | 'optimizer': model_optim.state_dict(), 175 | 'epoch': epoch + 1, 176 | 'args': self.args, 177 | } 178 | 179 | torch.save(save_dict, path + '/' + 'pretrain_checkpoint.pth') 180 | 181 | return self.model 182 | 183 | def train_one_epoch(self, model_optim, data_loader_cycle, criterion, epoch, train_steps, scaler, lr_schedule): 184 | current_device = torch.cuda.current_device() 185 | train_loss_set = [] 186 | 187 | acc_it = self.args.acc_it 188 | max_norm = self.args.clip_grad 189 | min_keep_ratio = self.args.min_keep_ratio 190 | 191 | self.model.train() 192 | epoch_time = time.time() 193 | 
self.model.zero_grad(set_to_none=True) 194 | loss_sum_display = 0 195 | 196 | for i, (sample_init, task_id) in enumerate(data_loader_cycle): 197 | it = train_steps * epoch + i 198 | for _, param_group in enumerate(model_optim.param_groups): 199 | param_group["lr"] = lr_schedule[it] 200 | 201 | # Get batch data based on the real batch size of each task: avoid OOM for large samples 202 | task_name = self.task_data_config_list[task_id][1]['task_name'] 203 | small_batch_size = self.task_data_config_list[task_id][1]['max_batch'] 204 | sample_list = self.get_multi_source_data( 205 | sample_init, task_name, small_batch_size, min_keep_ratio=min_keep_ratio) 206 | len_sample_list = len(sample_list) 207 | 208 | # Accumulate gradients of mulitple samples 209 | for sample_idx in range(len_sample_list): 210 | sample = sample_list[sample_idx] 211 | x_enc, x_mark_enc, pad_mask = sample 212 | with torch.cuda.amp.autocast(): 213 | model_output = self.model( 214 | x_enc=x_enc, x_mark_enc=x_mark_enc, task_id=task_id, task_name=task_name, enable_mask=True) 215 | loss_dict = criterion(model_output, x_enc, pad_mask) 216 | loss = loss_dict['loss'] 217 | loss /= acc_it 218 | loss /= len_sample_list 219 | if sample_idx < len_sample_list-1: 220 | norm_value = scaler(loss, model_optim, clip_grad=max_norm, 221 | parameters=self.model.parameters(), create_graph=False, update_grad=False) 222 | 223 | loss_display = loss.item()*len_sample_list*acc_it 224 | train_loss_set.append(loss_display) 225 | 226 | norm_value = scaler(loss, model_optim, clip_grad=max_norm, 227 | parameters=self.model.parameters(), create_graph=False, update_grad=((i + 1) % acc_it == 0)) 228 | 229 | if (i+1) % acc_it == 0: 230 | model_optim.zero_grad() 231 | torch.cuda.synchronize() 232 | 233 | loss_sum_display += loss_display 234 | 235 | # release memory to avoid OOM 236 | del sample_init 237 | del sample_list 238 | if torch.cuda.memory_reserved(current_device) > 30*1e9: 239 | torch.cuda.empty_cache() 240 | 241 | if 
def get_multi_source_data(self, this_batch, task_name, small_batch_size, min_keep_ratio=None):
    """
    Splits the input batch into smaller batches based on the specified small_batch_size.

    Args:
        this_batch (tuple): The input batch containing all data of a task.
            Unpacked as (batch_x, _, batch_x_mark, _) for forecasting tasks and
            as (batch_x, _, padding_mask) for classification tasks.
        task_name (str): The name of the task; must contain
            "long_term_forecast" or "classification".
        small_batch_size (int): The size of the smaller batches to split the data into.
        min_keep_ratio (float, optional): When given, the sequence (dim 1) is
            cropped to a random prefix whose length is drawn from
            [min_keep_ratio, 1) of the full length and rounded up to a multiple
            of `self.args.patch_len`.

    Returns:
        list: A list of tuples, where each tuple contains a smaller batch of data, marks, and padding masks.

    Raises:
        ValueError: if `task_name` matches neither supported task type.
    """

    def split_tensor(tensor, size):
        # Consecutive chunks along dim 0; the last chunk may be shorter.
        return [tensor[i:i + size] for i in range(0, tensor.size(0), size)]

    if "long_term_forecast" in task_name:
        batch_x, _, batch_x_mark, _ = this_batch
        batch_x = batch_x.float().to(self.device_id)
        batch_x_mark = batch_x_mark.float().to(self.device_id)
        # Collapse the last dimension of the marks to one value per position.
        batch_x_mark = batch_x_mark.max(dim=-1)[0]
        # No padding information in this branch: every position is marked valid.
        padding_mask = torch.ones(
            (batch_x.shape[0], batch_x.shape[1]), dtype=torch.bool).to(self.device_id)
    elif "classification" in task_name:
        batch_x, _, padding_mask = this_batch
        batch_x = batch_x.float().to(self.device_id)
        # The float copy of the padding mask doubles as the marks here.
        batch_x_mark = padding_mask.float().to(self.device_id)
        padding_mask = batch_x_mark.bool().to(self.device_id)
    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError(
            f'Unsupported task_name "{task_name}": expected it to contain '
            '"long_term_forecast" or "classification"')

    if min_keep_ratio is not None:
        keep_ratio = torch.rand(
            1, device=batch_x.device) * (1.0 - min_keep_ratio) + min_keep_ratio
        seq_len = batch_x.shape[1]
        patch_len = self.args.patch_len
        len_keep = int((seq_len * keep_ratio).long().item())
        # Round up to a whole number of patches (integer ceil-division) and
        # keep at least one patch so the cropped sequence is never empty.
        len_keep = -(-len_keep // patch_len) * patch_len
        len_keep = max(len_keep, patch_len)

        batch_x = batch_x[:, :len_keep]
        batch_x_mark = batch_x_mark[:, :len_keep]
        padding_mask = padding_mask[:, :len_keep]

    split_batch_x = split_tensor(batch_x, small_batch_size)
    split_batch_x_mark = split_tensor(batch_x_mark, small_batch_size)
    split_padding_mask = split_tensor(padding_mask, small_batch_size)

    return list(zip(split_batch_x, split_batch_x_mark, split_padding_mask))
318 | 319 | Returns: 320 | None 321 | """ 322 | num_elements = holdout_memory * 1024 * 1024 * 1024 // 4 323 | extra_mem = torch.empty( 324 | num_elements, dtype=torch.float32, device=self.device_id) 325 | 326 | model_tmp = self._build_model(ddp=False) 327 | criterion = UnifiedMaskRecLoss().to(self.device_id) 328 | model_tmp.train() 329 | model_tmp.zero_grad(set_to_none=True) 330 | 331 | for data_loader_id in range(data_loader_cycle.num_dataloaders): 332 | batch_size = 1 333 | max_batch_size = 0 334 | torch.cuda.synchronize() 335 | model_tmp.zero_grad(set_to_none=True) 336 | while True: 337 | try: 338 | sample, task_id = data_loader_cycle.generate_fake_samples_for_batch( 339 | data_loader_id, batch_size) 340 | task_name = self.task_data_config_list[task_id][1]['task_name'] 341 | if "long_term_forecast" in task_name: 342 | batch_x, _, batch_x_mark, _ = sample 343 | batch_x = batch_x.float().to(self.device_id) 344 | batch_x_mark = batch_x_mark.float().to(self.device_id) 345 | elif "classification" in task_name: 346 | batch_x, _, batch_x_mark = sample 347 | batch_x = batch_x.float().to(self.device_id) 348 | batch_x_mark = torch.ones( 349 | (batch_x.shape[0], batch_x.shape[1]), dtype=torch.bool).to(self.device_id) 350 | 351 | print(task_id, task_name, 352 | sample[0].shape, "max batch size", max_batch_size) 353 | with torch.cuda.amp.autocast(): 354 | model_output = model_tmp( 355 | x_enc=batch_x, x_mark_enc=batch_x_mark, task_id=task_id, task_name=task_name, enable_mask=True) 356 | loss = 0.0 357 | for each in model_output: 358 | if each is not None: 359 | loss += each.sum() 360 | 361 | loss.backward() 362 | max_batch_size = batch_size 363 | batch_size *= 2 364 | 365 | if max_batch_size >= self.args.batch_size: 366 | print("can support default batchsize:", 367 | self.args.batch_size, max_batch_size) 368 | self.task_data_config_list[task_id][1]['max_batch'] = max_batch_size 369 | self.task_data_config_list[task_id][1]['checkpointing'] = False 370 | break 371 | 372 | 
def calculate_unfold_output_length(input_length, size, step):
    """
    Return the window count given by (input_length - size) // step + 1,
    i.e. windows of length `size` placed every `step` positions.
    """
    remaining_after_first = input_length - size
    return remaining_after_first // step + 1
class DynamicLinear(nn.Module):
    """
    A dynamic linear layer that can interpolate the weight size to support any given input and output feature dimension.

    The weight matrix has shape (out_features, in_features). Its first
    `fixed_in` input columns and the remaining columns are resized
    independently (bilinear interpolation), so the first `fixed_in` input
    features always map onto the same block of weights.

    Args:
        in_features: input width the weights are stored at.
        out_features: output width the weights are stored at.
        fixed_in: number of leading input features with a fixed-width weight
            block; must be smaller than `in_features`.
        bias: when False, no bias parameter is created (`self.bias` is None).
    """

    def __init__(self, in_features=None, out_features=None, fixed_in=0, bias=True):
        super(DynamicLinear, self).__init__()
        assert fixed_in < in_features, "fixed_in < in_features is required !!!"
        self.in_features = in_features
        self.out_features = out_features
        self.weights = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # The flag used to be ignored and a bias was always created.
            self.register_parameter('bias', None)
        self.fixed_in = fixed_in

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weights with Kaiming-uniform and the bias uniformly in +-1/sqrt(fan_in)."""
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, out_features):
        """
        Forward pass for the dynamic linear layer.

        Args:
            x: input whose last dimension is the input feature dimension; the
                first `fixed_in` features use the fixed weight block.
            out_features: requested size of the output's last dimension.
        Returns:
            `x` with its last dimension mapped to `out_features`.
        """
        fixed_weights = self.weights[:, :self.fixed_in]
        dynamic_weights = self.weights[:, self.fixed_in:]
        this_bias = self.bias
        in_features = x.shape[-1]

        if in_features != self.weights.size(1) or out_features != self.weights.size(0):
            # F.interpolate(mode='bilinear') needs a 4-D tensor, so the 2-D
            # weights are wrapped in two singleton dims and unwrapped after.
            dynamic_weights = F.interpolate(dynamic_weights.unsqueeze(0).unsqueeze(0), size=(
                out_features, in_features-self.fixed_in), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
            if self.fixed_in != 0:
                # The fixed block keeps its input width; only its output size changes.
                fixed_weights = F.interpolate(fixed_weights.unsqueeze(0).unsqueeze(0), size=(
                    out_features, self.fixed_in), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
            if this_bias is not None and out_features != self.weights.size(0):
                this_bias = F.interpolate(this_bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), size=(
                    1, out_features), mode='bilinear', align_corners=False).squeeze(0).squeeze(0).squeeze(0)
        return F.linear(x, torch.cat((fixed_weights, dynamic_weights), dim=1), this_bias)
class LearnablePositionalEmbedding(nn.Module):
    """
    Trainable positional embedding initialised with the sinusoidal encoding.

    `self.pe` is a parameter of shape (1, 1, max_len, d_model). Even feature
    columns start as sin(position * div_term) and odd columns as
    cos(position * div_term).

    Args:
        d_model: embedding width (last dimension of `self.pe`).
        max_len: number of positions stored.
    """

    def __init__(self, d_model, max_len=5000):
        super(LearnablePositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        self.pe = nn.Parameter(torch.zeros(
            1, 1, max_len, d_model), requires_grad=True)

        pe = torch.zeros(max_len, d_model).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns;
        # trimming div_term avoids a shape mismatch in that case.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        pe = pe.unsqueeze(0).unsqueeze(0)
        self.pe.data.copy_(pe.float())
        del pe

    def forward(self, x, offset=0):
        """
        Return the embeddings for `x.size(2)` consecutive positions.

        Args:
            x: only `x.size(2)` is read, as the number of positions.
            offset: index of the first position to return.
        Returns:
            Slice `self.pe[:, :, offset:offset + x.size(2)]`.
        """
        return self.pe[:, :, offset:offset+x.size(2)]
class VarAttention(nn.Module): 253 | 254 | def __init__( 255 | self, 256 | dim, 257 | num_heads=8, 258 | qkv_bias=False, 259 | qk_norm=False, 260 | attn_drop=0., 261 | proj_drop=0., 262 | norm_layer=nn.LayerNorm, 263 | ): 264 | super().__init__() 265 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 266 | self.num_heads = num_heads 267 | self.head_dim = dim // num_heads 268 | self.scale = self.head_dim ** -0.5 269 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 270 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 271 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 272 | self.attn_drop = nn.Dropout(attn_drop) 273 | self.proj = nn.Linear(dim, dim) 274 | self.proj_drop = nn.Dropout(proj_drop) 275 | 276 | def forward(self, x): 277 | B, N, P, C = x.shape 278 | 279 | qkv = self.qkv(x).reshape(B, N, P, 3, self.num_heads, 280 | self.head_dim).permute(3, 0, 2, 4, 1, 5) 281 | q, k, v = qkv.unbind(0) 282 | q, k = self.q_norm(q), self.k_norm(k) 283 | 284 | q = q.mean(dim=1, keepdim=False) 285 | k = k.mean(dim=1, keepdim=False) 286 | v = v.permute(0, 2, 3, 4, 1).reshape(B, self.num_heads, N, -1) 287 | 288 | x = F.scaled_dot_product_attention( 289 | q, k, v, 290 | dropout_p=self.attn_drop.p if self.training else 0., 291 | ) 292 | 293 | x = x.view(B, self.num_heads, N, -1, P).permute(0, 294 | 2, 4, 1, 3).reshape(B, N, P, -1) 295 | x = self.proj(x) 296 | x = self.proj_drop(x) 297 | return x 298 | 299 | 300 | class GateLayer(nn.Module): 301 | def __init__(self, dim, init_values=1e-5, inplace=False): 302 | super().__init__() 303 | self.inplace = inplace 304 | self.gate = nn.Linear(dim, 1) 305 | 306 | def forward(self, x): 307 | gate_value = self.gate(x) 308 | return gate_value.sigmoid() * x 309 | 310 | 311 | class SeqAttBlock(nn.Module): 312 | 313 | def __init__( 314 | self, 315 | dim, 316 | num_heads, 317 | qkv_bias=False, 318 | qk_norm=False, 319 | proj_drop=0., 320 | attn_drop=0., 321 | init_values=None, 322 | 
drop_path=0., 323 | norm_layer=nn.LayerNorm, 324 | ): 325 | super().__init__() 326 | self.norm1 = norm_layer(dim) 327 | self.attn_seq = SeqAttention( 328 | dim, 329 | num_heads=num_heads, 330 | qkv_bias=qkv_bias, 331 | qk_norm=qk_norm, 332 | attn_drop=attn_drop, 333 | proj_drop=proj_drop, 334 | norm_layer=norm_layer, 335 | ) 336 | 337 | self.ls1 = GateLayer(dim, init_values=init_values) 338 | self.drop_path1 = DropPath( 339 | drop_path) if drop_path > 0. else nn.Identity() 340 | self.proj = nn.Linear(dim, dim) 341 | 342 | def forward(self, x, attn_mask): 343 | x_input = x 344 | x = self.norm1(x) 345 | n_vars, n_seqs = x.shape[1], x.shape[2] 346 | x = torch.reshape( 347 | x, (-1, x.shape[-2], x.shape[-1])) 348 | x = self.attn_seq(x, attn_mask) 349 | x = torch.reshape( 350 | x, (-1, n_vars, n_seqs, x.shape[-1])) 351 | x = x_input + self.drop_path1(self.ls1(x)) 352 | return x 353 | 354 | 355 | class VarAttBlock(nn.Module): 356 | 357 | def __init__( 358 | self, 359 | dim, 360 | num_heads, 361 | qkv_bias=False, 362 | qk_norm=False, 363 | proj_drop=0., 364 | attn_drop=0., 365 | init_values=None, 366 | drop_path=0., 367 | norm_layer=nn.LayerNorm, 368 | ): 369 | super().__init__() 370 | self.norm1 = norm_layer(dim) 371 | self.attn_var = VarAttention( 372 | dim, 373 | num_heads=num_heads, 374 | qkv_bias=qkv_bias, 375 | qk_norm=qk_norm, 376 | attn_drop=attn_drop, 377 | proj_drop=proj_drop, 378 | norm_layer=norm_layer, 379 | ) 380 | self.ls1 = GateLayer(dim, init_values=init_values) 381 | self.drop_path1 = DropPath( 382 | drop_path) if drop_path > 0. 
else nn.Identity() 383 | self.proj = nn.Linear(dim, dim) 384 | 385 | def forward(self, x): 386 | x = x + self.drop_path1(self.ls1(self.attn_var(self.norm1(x)))) 387 | return x 388 | 389 | 390 | class MLPBlock(nn.Module): 391 | 392 | def __init__( 393 | self, 394 | dim, 395 | mlp_ratio=4., 396 | proj_drop=0., 397 | init_values=None, 398 | drop_path=0., 399 | act_layer=nn.GELU, 400 | norm_layer=nn.LayerNorm, 401 | mlp_layer=None, 402 | prefix_token_length=0, 403 | ): 404 | super().__init__() 405 | self.norm2 = norm_layer(dim) 406 | if mlp_layer is DynamicLinearMlp: 407 | self.mlp = mlp_layer( 408 | in_features=dim, 409 | hidden_features=int(dim * mlp_ratio), 410 | act_layer=act_layer, 411 | drop=proj_drop, 412 | prefix_token_length=prefix_token_length, 413 | ) 414 | else: 415 | self.mlp = mlp_layer( 416 | in_features=dim, 417 | hidden_features=int(dim * mlp_ratio), 418 | act_layer=act_layer, 419 | drop=proj_drop, 420 | ) 421 | self.ls2 = GateLayer(dim, init_values=init_values) 422 | self.drop_path2 = DropPath( 423 | drop_path) if drop_path > 0. 
else nn.Identity() 424 | 425 | def forward(self, x, prefix_seq_len=None): 426 | if prefix_seq_len is not None: 427 | x = x + \ 428 | self.drop_path2( 429 | self.ls2(self.mlp(self.norm2(x), prefix_seq_len=prefix_seq_len))) 430 | else: 431 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 432 | return x 433 | 434 | 435 | class BasicBlock(nn.Module): 436 | def __init__( 437 | self, 438 | dim, 439 | num_heads, 440 | mlp_ratio=8., 441 | qkv_bias=False, 442 | qk_norm=False, 443 | proj_drop=0., 444 | attn_drop=0., 445 | init_values=None, 446 | drop_path=0., 447 | act_layer=nn.GELU, 448 | norm_layer=nn.LayerNorm, 449 | prefix_token_length=0, 450 | ): 451 | super().__init__() 452 | self.seq_att_block = SeqAttBlock(dim=dim, num_heads=num_heads, 453 | qkv_bias=qkv_bias, qk_norm=qk_norm, 454 | attn_drop=attn_drop, init_values=init_values, proj_drop=proj_drop, 455 | drop_path=drop_path, norm_layer=norm_layer) 456 | 457 | self.var_att_block = VarAttBlock(dim=dim, num_heads=num_heads, 458 | qkv_bias=qkv_bias, qk_norm=qk_norm, 459 | attn_drop=attn_drop, init_values=init_values, proj_drop=proj_drop, 460 | drop_path=drop_path, norm_layer=norm_layer) 461 | 462 | self.dynamic_mlp = MLPBlock(dim=dim, mlp_ratio=mlp_ratio, mlp_layer=DynamicLinearMlp, 463 | proj_drop=proj_drop, init_values=init_values, drop_path=drop_path, 464 | act_layer=act_layer, norm_layer=norm_layer, 465 | prefix_token_length=prefix_token_length) 466 | 467 | def forward(self, x, prefix_seq_len, attn_mask): 468 | x = self.seq_att_block(x, attn_mask) 469 | x = self.var_att_block(x) 470 | x = self.dynamic_mlp(x, prefix_seq_len=prefix_seq_len) 471 | return x 472 | 473 | 474 | class PatchEmbedding(nn.Module): 475 | def __init__(self, d_model, patch_len, stride, padding, dropout): 476 | super(PatchEmbedding, self).__init__() 477 | # Patching 478 | self.patch_len = patch_len 479 | self.stride = stride 480 | assert self.patch_len == self.stride, "non-overlap" 481 | self.value_embedding = nn.Linear(patch_len, 
d_model, bias=False) 482 | self.dropout = nn.Dropout(dropout) 483 | 484 | def forward(self, x): 485 | n_vars = x.shape[1] 486 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 487 | x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) 488 | x = self.value_embedding(x) 489 | return self.dropout(x), n_vars 490 | 491 | 492 | class CLSHead(nn.Module): 493 | def __init__(self, d_model, head_dropout=0): 494 | super().__init__() 495 | d_mid = d_model 496 | self.proj_in = nn.Linear(d_model, d_mid) 497 | self.cross_att = CrossAttention(d_mid) 498 | 499 | self.mlp = MLPBlock(dim=d_mid, mlp_ratio=8, mlp_layer=Mlp, 500 | proj_drop=head_dropout, init_values=None, drop_path=0.0, 501 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, 502 | prefix_token_length=None) 503 | 504 | def forward(self, x, category_token=None, return_feature=False): 505 | x = self.proj_in(x) 506 | B, V, L, C = x.shape 507 | x = x.view(-1, L, C) 508 | cls_token = x[:, -1:] 509 | cls_token = self.cross_att(x, query=cls_token) 510 | cls_token = cls_token.reshape(B, V, -1, C) 511 | 512 | cls_token = self.mlp(cls_token) 513 | if return_feature: 514 | return cls_token 515 | m = category_token.shape[2] 516 | cls_token = cls_token.expand(B, V, m, C) 517 | distance = torch.einsum('nvkc,nvmc->nvm', cls_token, category_token) 518 | 519 | distance = distance.mean(dim=1) 520 | return distance 521 | 522 | 523 | class ForecastHead(nn.Module): 524 | def __init__(self, d_model, patch_len, stride, pad, head_dropout=0, prefix_token_length=None): 525 | super().__init__() 526 | d_mid = d_model 527 | self.proj_in = nn.Linear(d_model, d_mid) 528 | self.mlp = Mlp( 529 | in_features=d_model, 530 | hidden_features=int(d_model * 4), 531 | act_layer=nn.GELU, 532 | drop=head_dropout, 533 | ) 534 | self.proj_out = nn.Linear(d_model, patch_len) 535 | self.pad = pad 536 | self.patch_len = patch_len 537 | self.stride = stride 538 | self.pos_proj = DynamicLinear( 539 | in_features=128, out_features=128, 
fixed_in=prefix_token_length) 540 | 541 | def forward(self, x_full, pred_len, token_len): 542 | x_full = self.proj_in(x_full) 543 | x_pred = x_full[:, :, -token_len:] 544 | x = x_full.transpose(-1, -2) 545 | x = self.pos_proj(x, token_len) 546 | x = x.transpose(-1, -2) 547 | x = x + x_pred 548 | x = self.mlp(x) 549 | x = self.proj_out(x) 550 | 551 | bs, n_vars = x.shape[0], x.shape[1] 552 | x = x.reshape(-1, x.shape[-2], x.shape[-1]) 553 | x = x.permute(0, 2, 1) 554 | x = torch.nn.functional.fold(x, output_size=( 555 | pred_len, 1), kernel_size=(self.patch_len, 1), stride=(self.stride, 1)) 556 | x = x.squeeze(dim=-1) 557 | x = x.reshape(bs, n_vars, -1) 558 | x = x.permute(0, 2, 1) 559 | return x 560 | 561 | 562 | class Model(nn.Module): 563 | """ 564 | UniTS: Building a Unified Time Series Model 565 | """ 566 | 567 | def __init__(self, args, configs_list, pretrain=False): 568 | super().__init__() 569 | 570 | if pretrain: 571 | self.right_prob = args.right_prob 572 | self.min_mask_ratio = args.min_mask_ratio 573 | self.max_mask_ratio = args.max_mask_ratio 574 | 575 | # Tokens settings 576 | self.num_task = len(configs_list) 577 | self.prompt_token =nn.Parameter(torch.zeros(1, 1, args.prompt_num, args.d_model)) 578 | torch.nn.init.normal_(self.prompt_token, std=.02) 579 | self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, args.d_model)) 580 | self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, args.d_model)) 581 | torch.nn.init.normal_(self.cls_token, std=.02) 582 | self.category_tokens = nn.ParameterDict({}) 583 | 584 | for i in range(self.num_task): 585 | task_data_name = configs_list[i][0] 586 | if configs_list[i][1]['task_name'] == 'classification': 587 | self.category_tokens[task_data_name] = torch.zeros( 588 | 1, configs_list[i][1]['enc_in'], configs_list[i][1]['num_class'], args.d_model) 589 | torch.nn.init.normal_( 590 | self.category_tokens[task_data_name], std=.02) 591 | 592 | self.cls_nums = {} 593 | for i in range(self.num_task): 594 | task_data_name = 
configs_list[i][0] 595 | if configs_list[i][1]['task_name'] == 'classification': 596 | self.cls_nums[task_data_name] = configs_list[i][1]['num_class'] 597 | elif configs_list[i][1]['task_name'] == 'long_term_forecast': 598 | remainder = configs_list[i][1]['seq_len'] % args.patch_len 599 | if remainder == 0: 600 | padding = 0 601 | else: 602 | padding = args.patch_len - remainder 603 | input_token_len = calculate_unfold_output_length( 604 | configs_list[i][1]['seq_len']+padding, args.stride, args.patch_len) 605 | input_pad = args.stride * \ 606 | (input_token_len - 1) + args.patch_len - \ 607 | configs_list[i][1]['seq_len'] 608 | pred_token_len = calculate_unfold_output_length( 609 | configs_list[i][1]['pred_len']-input_pad, args.stride, args.patch_len) 610 | real_len = configs_list[i][1]['seq_len'] + \ 611 | configs_list[i][1]['pred_len'] 612 | self.cls_nums[task_data_name] = [pred_token_len, 613 | configs_list[i][1]['pred_len'], real_len] 614 | 615 | self.configs_list = configs_list 616 | 617 | ### model settings ### 618 | self.prompt_num = args.prompt_num 619 | self.stride = args.stride 620 | self.pad = args.stride 621 | self.patch_len = args.patch_len 622 | 623 | # input processing 624 | self.patch_embeddings = PatchEmbedding( 625 | args.d_model, args.patch_len, args.stride, args.stride, args.dropout) 626 | self.position_embedding = LearnablePositionalEmbedding(args.d_model) 627 | self.prompt2forecat = DynamicLinear(128, 128, fixed_in=args.prompt_num) 628 | 629 | # basic blocks 630 | self.block_num = args.e_layers 631 | self.blocks = nn.ModuleList( 632 | [BasicBlock(dim=args.d_model, num_heads=args.n_heads, qkv_bias=False, qk_norm=False, 633 | mlp_ratio=8., proj_drop=args.dropout, attn_drop=0., drop_path=0., 634 | init_values=None, prefix_token_length=args.prompt_num) for l in range(args.e_layers)] 635 | ) 636 | 637 | # output processing 638 | self.cls_head = CLSHead(args.d_model, head_dropout=args.dropout) 639 | self.forecast_head = ForecastHead( 640 | 
args.d_model, args.patch_len, args.stride, args.stride, prefix_token_length=args.prompt_num, head_dropout=args.dropout) 641 | if pretrain: 642 | self.pretrain_head = ForecastHead( 643 | args.d_model, args.patch_len, args.stride, args.stride, prefix_token_length=1, head_dropout=args.dropout) 644 | 645 | def tokenize(self, x, mask=None): 646 | # Normalization from Non-stationary Transformer 647 | means = x.mean(1, keepdim=True).detach() 648 | x = x - means 649 | if mask is not None: 650 | x = x.masked_fill(mask == 0, 0) 651 | stdev = torch.sqrt(torch.sum(x * x, dim=1) / 652 | torch.sum(mask == 1, dim=1) + 1e-5) 653 | stdev = stdev.unsqueeze(dim=1) 654 | else: 655 | stdev = torch.sqrt( 656 | torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5) 657 | x /= stdev 658 | x = x.permute(0, 2, 1) 659 | remainder = x.shape[2] % self.patch_len 660 | if remainder != 0: 661 | padding = self.patch_len - remainder 662 | x = F.pad(x, (0, padding)) 663 | else: 664 | padding = 0 665 | x, n_vars = self.patch_embeddings(x) 666 | return x, means, stdev, n_vars, padding 667 | 668 | def prepare_prompt(self, x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name=None, mask=None): 669 | x = torch.reshape( 670 | x, (-1, n_vars, x.shape[-2], x.shape[-1])) 671 | # append prompt tokens 672 | this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1) 673 | 674 | if task_name == 'forecast': 675 | this_mask_prompt = task_prompt.repeat( 676 | x.shape[0], 1, task_prompt_num, 1) 677 | init_full_input = torch.cat( 678 | (this_prompt, x, this_mask_prompt), dim=-2) 679 | init_mask_prompt = self.prompt2forecat(init_full_input.transpose( 680 | -1, -2), init_full_input.shape[2]-prefix_prompt.shape[2]).transpose(-1, -2) 681 | this_function_prompt = init_mask_prompt[:, :, -task_prompt_num:] 682 | x = torch.cat((this_prompt, x, this_function_prompt), dim=2) 683 | x[:, :, self.prompt_num:] = x[:, :, self.prompt_num:] + \ 684 | self.position_embedding(x[:, :, self.prompt_num:]) 685 | elif task_name 
== 'classification': 686 | this_function_prompt = task_prompt.repeat(x.shape[0], 1, 1, 1) 687 | x = x + self.position_embedding(x) 688 | x = torch.cat((this_prompt, x, this_function_prompt), dim=2) 689 | elif task_name == 'imputation': 690 | # fill the masked parts with mask tokens 691 | # for imputation, masked is 0, unmasked is 1, so here to reverse mask 692 | mask = 1-mask 693 | mask = mask.permute(0, 2, 1) 694 | mask = self.mark2token(mask) 695 | mask_repeat = mask.unsqueeze(dim=-1) 696 | 697 | mask_token = task_prompt 698 | mask_repeat = mask_repeat.repeat(1, 1, 1, x.shape[-1]) 699 | x = x * (1-mask_repeat) + mask_token * mask_repeat 700 | 701 | init_full_input = torch.cat((this_prompt, x), dim=-2) 702 | init_mask_prompt = self.prompt2forecat( 703 | init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2) 704 | # keep the unmasked tokens and fill the masked ones with init_mask_prompt. 705 | x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat 706 | x = x + self.position_embedding(x) 707 | x = torch.cat((this_prompt, x), dim=2) 708 | elif task_name == 'anomaly_detection': 709 | x = x + self.position_embedding(x) 710 | x = torch.cat((this_prompt, x), dim=2) 711 | 712 | return x 713 | 714 | def mark2token(self, x_mark): 715 | x_mark = x_mark.unfold( 716 | dimension=-1, size=self.patch_len, step=self.stride) 717 | x_mark = x_mark.mean(dim=-1) 718 | x_mark = (x_mark > 0).float() 719 | return x_mark 720 | 721 | def backbone(self, x, prefix_len, seq_len): 722 | attn_mask = None 723 | for block in self.blocks: 724 | x = block(x, prefix_seq_len=prefix_len + 725 | seq_len, attn_mask=attn_mask) 726 | return x 727 | 728 | def forecast(self, x, x_mark, task_id): 729 | task_data_name = self.configs_list[task_id][0] 730 | task_prompt_num = self.cls_nums[task_data_name][0] 731 | task_seq_num = self.cls_nums[task_data_name][1] 732 | real_seq_len = self.cls_nums[task_data_name][2] 733 | x, means, stdev, n_vars, _ = self.tokenize(x) 734 | prefix_prompt = 
self.prompt_token.repeat(1,n_vars,1,1) 735 | task_prompt = self.mask_token.repeat(1,n_vars,1,1) 736 | 737 | x = self.prepare_prompt( 738 | x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name='forecast') 739 | 740 | seq_token_len = x.shape[-2]-prefix_prompt.shape[2] 741 | x = self.backbone(x, prefix_prompt.shape[2], seq_token_len) 742 | 743 | x = self.forecast_head( 744 | x, real_seq_len, seq_token_len) 745 | x = x[:, -task_seq_num:] 746 | 747 | # De-Normalization from Non-stationary Transformer 748 | x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1)) 749 | x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1)) 750 | 751 | return x 752 | 753 | def classification(self, x, x_mark, task_id): 754 | task_data_name = self.configs_list[task_id][0] 755 | task_prompt_num = 1 756 | category_token = self.category_tokens[task_data_name] 757 | 758 | x, means, stdev, n_vars, _ = self.tokenize(x) 759 | prefix_prompt = self.prompt_token.repeat(1,n_vars,1,1) 760 | task_prompt = self.cls_token.repeat(1,n_vars,1,1) 761 | 762 | seq_len = x.shape[-2] 763 | 764 | x = self.prepare_prompt( 765 | x, n_vars, prefix_prompt, task_prompt, task_prompt_num, task_name='classification') 766 | 767 | x = self.backbone(x, prefix_prompt.shape[2], seq_len) 768 | 769 | x = self.cls_head(x, category_token) 770 | 771 | return x 772 | 773 | def imputation(self, x, x_mark, mask, task_id): 774 | 775 | seq_len = x.shape[1] 776 | x, means, stdev, n_vars, padding = self.tokenize(x, mask) 777 | prefix_prompt = self.prompt_token.repeat(1,n_vars,1,1) 778 | task_prompt = self.mask_token.repeat(1,n_vars,1,1) 779 | 780 | x = self.prepare_prompt( 781 | x, n_vars, prefix_prompt, task_prompt, None, mask=mask, task_name='imputation') 782 | seq_token_len = x.shape[-2]-prefix_prompt.shape[2] 783 | x = self.backbone(x, prefix_prompt.shape[2], seq_token_len) 784 | 785 | x = self.forecast_head( 786 | x, seq_len+padding, seq_token_len) 787 | x = x[:, :seq_len] 788 | 789 | # De-Normalization 
from Non-stationary Transformer 790 | x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1)) 791 | x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1)) 792 | 793 | return x 794 | 795 | def anomaly_detection(self, x, x_mark, task_id): 796 | seq_len = x.shape[1] 797 | x, means, stdev, n_vars, padding = self.tokenize(x) 798 | prefix_prompt = self.prompt_token.repeat(1,n_vars,1,1) 799 | 800 | x = self.prepare_prompt(x, n_vars, prefix_prompt, 801 | None, None, task_name='anomaly_detection') 802 | seq_token_len = x.shape[-2]-prefix_prompt.shape[2] 803 | x = self.backbone(x, prefix_prompt.shape[2], seq_token_len) 804 | 805 | x = self.forecast_head( 806 | x, seq_len+padding, seq_token_len) 807 | x = x[:, :seq_len] 808 | 809 | # De-Normalization from Non-stationary Transformer 810 | x = x * (stdev[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1)) 811 | x = x + (means[:, 0, :].unsqueeze(1).repeat(1, x.shape[1], 1)) 812 | 813 | return x 814 | 815 | def random_masking(self, x, min_mask_ratio, max_mask_ratio): 816 | """ 817 | Perform per-sample random masking. 
818 | """ 819 | N, V, L, D = x.shape # batch, var, length, dim 820 | 821 | # Calculate mask ratios and lengths to keep for each sample in the batch 822 | mask_ratios = torch.rand(N, device=x.device) * \ 823 | (max_mask_ratio - min_mask_ratio) + min_mask_ratio 824 | len_keeps = (L * (1 - mask_ratios)).long() 825 | 826 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 827 | 828 | # sort noise for each sample 829 | # ascend: small is keep, large is remove 830 | ids_shuffle = torch.argsort(noise, dim=1) 831 | ids_restore = torch.argsort(ids_shuffle, dim=1) 832 | 833 | # generate the binary mask: 0 is keep, 1 is remove 834 | mask = torch.ones([N, L], device=x.device) 835 | 836 | # Create a range tensor and compare with len_keeps for mask generation 837 | range_tensor = torch.arange(L, device=x.device).expand(N, L) 838 | mask = (range_tensor >= len_keeps.unsqueeze(1)) 839 | 840 | # unshuffle to get the binary mask 841 | mask = torch.gather(mask, dim=1, index=ids_restore) 842 | mask = mask.float() 843 | 844 | return mask 845 | 846 | def right_masking(self, x, min_mask_ratio, max_mask_ratio): 847 | N, V, L, D = x.shape # batch, var, length, dim 848 | 849 | # Randomly choose a mask ratio for each sample within the specified range 850 | mask_ratios = torch.rand(N, device=x.device) * \ 851 | (max_mask_ratio - min_mask_ratio) + min_mask_ratio 852 | len_keeps = (L * (1 - mask_ratios)).long() 853 | 854 | # Binary mask creation without a for loop 855 | len_keeps_matrix = len_keeps.unsqueeze(1).expand(N, L) 856 | indices = torch.arange(L, device=x.device).expand_as(len_keeps_matrix) 857 | mask = indices >= len_keeps_matrix 858 | mask = mask.float() 859 | 860 | return mask 861 | 862 | def choose_masking(self, x, right_prob, min_mask_ratio, max_mask_ratio): 863 | # Generate a random number to decide which masking function to use 864 | if torch.rand(1).item() > right_prob: 865 | return self.random_masking(x, min_mask_ratio, max_mask_ratio) 866 | else: 867 | return 
self.right_masking(x, min_mask_ratio, max_mask_ratio) 868 | 869 | def get_mask_seq(self, mask, seq_len): 870 | mask_seq = mask.unsqueeze(dim=-1).repeat(1, 1, self.patch_len) 871 | mask_seq = mask_seq.permute(0, 2, 1) 872 | mask_seq = mask_seq.masked_fill(mask_seq == 0, -1e9) 873 | # Fold operation 874 | mask_seq = torch.nn.functional.fold(mask_seq, output_size=( 875 | seq_len, 1), kernel_size=(self.patch_len, 1), stride=(self.stride, 1)) 876 | # Apply threshold to bring back to 0/1 values 877 | mask_seq = (mask_seq > 0).float() 878 | mask_seq = mask_seq.squeeze(dim=-1).squeeze(dim=1) 879 | return mask_seq 880 | 881 | def pretraining(self, x, x_mark, task_id, enable_mask=False): 882 | seq_len = x.shape[1] 883 | x, means, stdev, n_vars, padding = self.tokenize(x) 884 | seq_token_len = x.shape[-2] 885 | prefix_prompt = self.prompt_token.repeat(1,n_vars,1,1) 886 | mask_token = self.mask_token.repeat(1,n_vars,1,1) 887 | cls_token = self.cls_token.repeat(1,n_vars,1,1) 888 | 889 | # append prompt tokens 890 | x = torch.reshape( 891 | x, (-1, n_vars, x.shape[-2], x.shape[-1])) 892 | # prepare prompts 893 | this_prompt = prefix_prompt.repeat(x.shape[0], 1, 1, 1) 894 | 895 | if enable_mask: 896 | mask = self.choose_masking(x, self.right_prob, 897 | self.min_mask_ratio, self.max_mask_ratio) 898 | mask_repeat = mask.unsqueeze(dim=1).unsqueeze(dim=-1) 899 | mask_repeat = mask_repeat.repeat(1, x.shape[1], 1, x.shape[-1]) 900 | x = x * (1-mask_repeat) + mask_token * mask_repeat # todo 901 | 902 | init_full_input = torch.cat((this_prompt, x), dim=-2) 903 | init_mask_prompt = self.prompt2forecat( 904 | init_full_input.transpose(-1, -2), x.shape[2]).transpose(-1, -2) 905 | # keep the unmasked tokens and fill the masked ones with init_mask_prompt. 
906 | x = x * (1-mask_repeat) + init_mask_prompt * mask_repeat 907 | x = x + self.position_embedding(x) 908 | mask_seq = self.get_mask_seq(mask, seq_len+padding) 909 | mask_seq = mask_seq[:, :seq_len] 910 | this_function_prompt = cls_token.repeat(x.shape[0], 1, 1, 1) 911 | x = torch.cat((this_prompt, x, this_function_prompt), dim=2) 912 | 913 | x = self.backbone(x, prefix_prompt.shape[2], seq_token_len) 914 | 915 | if enable_mask: 916 | mask_dec_out = self.forecast_head( 917 | x[:, :, :-1], seq_len+padding, seq_token_len) 918 | mask_dec_out = mask_dec_out[:, :seq_len] 919 | # De-Normalization from Non-stationary Transformer 920 | mask_dec_out = mask_dec_out * \ 921 | (stdev[:, 0, :].unsqueeze(1).repeat( 922 | 1, mask_dec_out.shape[1], 1)) 923 | mask_dec_out = mask_dec_out + \ 924 | (means[:, 0, :].unsqueeze(1).repeat( 925 | 1, mask_dec_out.shape[1], 1)) 926 | cls_dec_out = self.cls_head(x, return_feature=True) 927 | # detach grad of the forecasting on tokens 928 | fused_dec_out = torch.cat( 929 | (cls_dec_out, x[:, :, self.prompt_num:-1].detach()), dim=2) 930 | cls_dec_out = self.pretrain_head( 931 | fused_dec_out, seq_len+padding, seq_token_len) 932 | cls_dec_out = cls_dec_out[:, :seq_len] 933 | cls_dec_out = cls_dec_out * \ 934 | (stdev[:, 0, :].unsqueeze(1).repeat( 935 | 1, cls_dec_out.shape[1], 1)) 936 | cls_dec_out = cls_dec_out + \ 937 | (means[:, 0, :].unsqueeze(1).repeat( 938 | 1, cls_dec_out.shape[1], 1)) 939 | 940 | return cls_dec_out, mask_dec_out, mask_seq 941 | else: 942 | return cls_dec_out 943 | 944 | def forward(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None, 945 | mask=None, task_id=None, task_name=None, enable_mask=None): 946 | if task_name == 'long_term_forecast' or task_name == 'short_term_forecast': 947 | dec_out = self.forecast(x_enc, x_mark_enc, task_id) 948 | return dec_out # [B, L, D] 949 | if task_name == 'imputation': 950 | dec_out = self.imputation( 951 | x_enc, x_mark_enc, mask, task_id) 952 | return dec_out # [B, L, D] 953 | if 
task_name == 'anomaly_detection': 954 | dec_out = self.anomaly_detection(x_enc, x_mark_enc, task_id) 955 | return dec_out # [B, L, D] 956 | if task_name == 'classification': 957 | dec_out = self.classification(x_enc, x_mark_enc, task_id) 958 | return dec_out # [B, N] 959 | if 'pretrain' in task_name: 960 | dec_out = self.pretraining(x_enc, x_mark_enc, task_id, 961 | enable_mask=enable_mask) 962 | return dec_out 963 | return None 964 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/UniTS/0e0281482864017cac8832b2651906ff5375a34e/models/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.4.0 2 | matplotlib==3.7.0 3 | numpy==1.23.5 4 | pandas==1.5.3 5 | patool==1.12 6 | reformer-pytorch==1.4.4 7 | scikit-learn==1.2.2 8 | scipy==1.10.1 9 | sktime==0.16.1 10 | sympy==1.11.1 11 | tqdm==4.64.1 12 | pyyaml 13 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from exp.exp_sup import Exp_All_Task as Exp_All_Task_SUP 4 | import random 5 | import numpy as np 6 | import wandb 7 | from utils.ddp import is_main_process, init_distributed_mode 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='UniTS supervised training') 12 | 13 | # basic config 14 | parser.add_argument('--task_name', type=str, required=False, default='ALL_task', 15 | help='task name') 16 | parser.add_argument('--is_training', type=int, 17 | required=True, default=1, help='status') 18 | parser.add_argument('--model_id', type=str, required=True, 19 | default='test', help='model id') 20 
# NOTE(review): continuation of the `if __name__ == '__main__':` body of run.py.
    # NOTE(review): required=True means the default below is never used.
    parser.add_argument('--model', type=str, required=True, default='UniTS',
                        help='model name')

    # data loader
    parser.add_argument('--data', type=str, required=False,
                        default='All', help='dataset type')
    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT',
                        help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--task_data_config_path', type=str,
                        default='exp/all_task.yaml', help='root path of the task and data yaml file')
    parser.add_argument('--subsample_pct', type=float,
                        default=None, help='subsample percent')

    # ddp
    parser.add_argument('--local-rank', type=int, help='local rank')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument('--num_workers', type=int, default=0,
                        help='data loader num workers')
    # NOTE(review): store_true combined with default=True means these two flags
    # are always True; passing them on the command line changes nothing.
    parser.add_argument("--memory_check", action="store_true", default=True)
    parser.add_argument("--large_model", action="store_true", default=True)

    # optimization
    parser.add_argument('--itr', type=int, default=1, help='experiments times')
    parser.add_argument('--train_epochs', type=int,
                        default=10, help='train epochs')
    parser.add_argument("--prompt_tune_epoch", type=int, default=0)
    parser.add_argument('--warmup_epochs', type=int,
                        default=0, help='warmup epochs')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size of train input data')
    parser.add_argument('--acc_it', type=int, default=1,
                        help='acc iteration to enlarge batch size')
    parser.add_argument('--learning_rate', type=float,
                        default=0.0001, help='optimizer learning rate')
    parser.add_argument('--min_lr', type=float, default=None,
                        help='optimizer min learning rate')
    parser.add_argument('--weight_decay', type=float,
                        default=0.0, help='optimizer weight decay')
    parser.add_argument('--layer_decay', type=float,
                        default=None, help='optimizer layer decay')
    parser.add_argument('--des', type=str, default='test',
                        help='exp description')
    parser.add_argument('--lradj', type=str,
                        default='supervised', help='adjust learning rate')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/',
                        help='save location of model checkpoints')
    parser.add_argument('--pretrained_weight', type=str, default=None,
                        help='location of pretrained model checkpoints')
    # Passed to wandb.init(mode=...) further down in this script.
    parser.add_argument('--debug', type=str,
                        default='enabled', help='disabled')
    parser.add_argument('--project_name', type=str,
                        default='tsfm-multitask', help='wandb project name')

    # model settings
    parser.add_argument('--d_model', type=int, default=512,
                        help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=2,
                        help='num of encoder layers')
    parser.add_argument("--share_embedding",
                        action="store_true", default=False)
    # NOTE(review): models/UniTS.py PatchEmbedding asserts patch_len == stride,
    # which these two defaults (16 and 8) do not satisfy.
    parser.add_argument("--patch_len", type=int, default=16)
    parser.add_argument("--stride", type=int, default=8)
    parser.add_argument("--prompt_num", type=int, default=5)

parser.add_argument('--fix_seed', type=int, default=None, help='seed') 93 | 94 | # task related settings 95 | # forecasting task 96 | parser.add_argument('--inverse', action='store_true', 97 | help='inverse output data', default=False) 98 | 99 | # inputation task 100 | parser.add_argument('--mask_rate', type=float, 101 | default=0.25, help='mask ratio') 102 | 103 | # anomaly detection task 104 | parser.add_argument('--anomaly_ratio', type=float, 105 | default=1.0, help='prior anomaly ratio (%)') 106 | 107 | # zero-shot-forecast-new-length 108 | parser.add_argument("--offset", type=int, default=0) 109 | parser.add_argument("--max_offset", type=int, default=0) 110 | parser.add_argument('--zero_shot_forecasting_new_length', 111 | type=str, default=None, help='unify') 112 | 113 | args = parser.parse_args() 114 | init_distributed_mode(args) 115 | if args.fix_seed is not None: 116 | random.seed(args.fix_seed) 117 | torch.manual_seed(args.fix_seed) 118 | np.random.seed(args.fix_seed) 119 | 120 | print('Args in experiment:') 121 | print(args) 122 | exp_name = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}'.format( 123 | args.task_name, 124 | args.model_id, 125 | args.model, 126 | args.data, 127 | args.features, 128 | args.d_model, 129 | args.e_layers, 130 | args.des) 131 | 132 | if int(args.prompt_tune_epoch) != 0: 133 | exp_name = 'Ptune'+str(args.prompt_tune_epoch)+'_'+exp_name 134 | print(exp_name) 135 | 136 | if is_main_process(): 137 | wandb.init( 138 | name=exp_name, 139 | # set the wandb project where this run will be logged 140 | project=args.project_name, 141 | # track hyperparameters and run metadata 142 | config=args, 143 | mode=args.debug, 144 | ) 145 | 146 | Exp = Exp_All_Task_SUP 147 | 148 | if args.is_training: 149 | for ii in range(args.itr): 150 | # setting record of experiments 151 | setting = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}_{}'.format( 152 | args.task_name, 153 | args.model_id, 154 | args.model, 155 | args.data, 156 | args.features, 157 | args.d_model, 158 | 
args.e_layers, 159 | args.des, ii) 160 | 161 | exp = Exp(args) # set experiments 162 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) 163 | exp.train(setting) 164 | else: 165 | ii = 0 166 | setting = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}_{}'.format( 167 | args.task_name, 168 | args.model_id, 169 | args.model, 170 | args.data, 171 | args.features, 172 | args.d_model, 173 | args.e_layers, 174 | args.des, ii) 175 | 176 | exp = Exp(args) # set experiments 177 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 178 | exp.test(setting, load_pretrain=True) 179 | torch.cuda.empty_cache() 180 | -------------------------------------------------------------------------------- /run_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from exp.exp_pretrain import Exp_All_Task as Exp_All_Task_SSL 4 | import random 5 | import numpy as np 6 | import wandb 7 | from utils.ddp import is_main_process, init_distributed_mode 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='UniTS Pretrain') 11 | parser.add_argument('--fix_seed', type=int, default=None, help='seed') 12 | # basic config 13 | parser.add_argument('--task_name', type=str, required=False, default='ALL_task', 14 | help='task name') 15 | parser.add_argument('--is_training', type=int, 16 | required=True, default=1, help='status') 17 | parser.add_argument('--model_id', type=str, required=True, 18 | default='test', help='model id') 19 | parser.add_argument('--model', type=str, required=True, default='UniTS', 20 | help='model name') 21 | 22 | # data loader 23 | parser.add_argument('--data', type=str, required=False, 24 | default='All', help='dataset type') 25 | parser.add_argument('--features', type=str, default='M', 26 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict 
univariate') 27 | parser.add_argument('--target', type=str, default='OT', 28 | help='target feature in S or MS task') 29 | parser.add_argument('--freq', type=str, default='h', 30 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h') 31 | parser.add_argument('--task_data_config_path', type=str, 32 | default='exp/all_task_pretrain.yaml', help='root path of the task and data yaml file') 33 | parser.add_argument('--subsample_pct', type=float, 34 | default=None, help='subsample percent') 35 | 36 | # pretrain 37 | parser.add_argument('--right_prob', type=float, 38 | default=1.0, help='right mask prob') 39 | parser.add_argument('--min_mask_ratio', type=float, 40 | default=0.5, help='min right mask prob') 41 | parser.add_argument('--max_mask_ratio', type=float, 42 | default=0.8, help='max right mask prob') 43 | parser.add_argument('--min_keep_ratio', type=float, default=None, 44 | help='min crop ratio for various length in pretraining') 45 | 46 | # ddp 47 | parser.add_argument('--local-rank', type=int, help='local rank') 48 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 49 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 50 | parser.add_argument('--num_workers', type=int, default=0, 51 | help='data loader num workers') 52 | 53 | # optimization 54 | parser.add_argument('--itr', type=int, default=1, help='experiments times') 55 | parser.add_argument('--train_epochs', type=int, 56 | default=10, help='train epochs') 57 | parser.add_argument('--warmup_epochs', type=int, 58 | default=0, help='warmup epochs') 59 | parser.add_argument('--batch_size', type=int, default=32, 60 | help='batch size of train input data') 61 | parser.add_argument('--acc_it', type=int, default=32, 62 | help='acc iteration to enlarge batch size') 63 | parser.add_argument('--learning_rate', type=float, 
64 | default=0.0001, help='optimizer learning rate') 65 | parser.add_argument('--min_lr', type=float, default=1e-6, 66 | help='optimizer learning rate') 67 | parser.add_argument('--beta2', type=float, 68 | default=0.999, help='optimizer beta2') 69 | parser.add_argument('--weight_decay', type=float, 70 | default=0.0, help='optimizer weight decay') 71 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout') 72 | parser.add_argument('--eps', type=float, default=1e-08, 73 | help='eps for optimizer') 74 | parser.add_argument('--des', type=str, default='test', 75 | help='exp description') 76 | parser.add_argument('--debug', type=str, 77 | default='enabled', help='disabled') 78 | parser.add_argument('--clip_grad', type=float, default=None, help="""Maximal parameter 79 | gradient norm if using gradient clipping.""") 80 | parser.add_argument('--checkpoints', type=str, 81 | default='./checkpoints/', help='location of model checkpoints') 82 | 83 | parser.add_argument("--memory_check", action="store_true", default=True) 84 | parser.add_argument("--large_model", action="store_true", default=True) 85 | 86 | # model settings 87 | parser.add_argument('--d_model', type=int, default=512, 88 | help='dimension of model') 89 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads') 90 | parser.add_argument('--e_layers', type=int, default=2, 91 | help='num of encoder layers') 92 | parser.add_argument("--patch_len", type=int, default=16) 93 | parser.add_argument("--stride", type=int, default=8) 94 | parser.add_argument("--prompt_num", type=int, default=10) 95 | 96 | args = parser.parse_args() 97 | init_distributed_mode(args) 98 | 99 | print('Args in experiment:') 100 | print(args) 101 | if args.fix_seed is not None: 102 | random.seed(args.fix_seed) 103 | torch.manual_seed(args.fix_seed) 104 | np.random.seed(args.fix_seed) 105 | exp_name = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}'.format( 106 | args.task_name, 107 | args.model_id, 108 | args.model, 109 | 
args.data, 110 | args.features, 111 | args.d_model, 112 | args.e_layers, 113 | args.des) 114 | 115 | if is_main_process(): 116 | wandb.init( 117 | name=exp_name, 118 | # set the wandb project where this run will be logged 119 | project="pretrain", 120 | # track hyperparameters and run metadata 121 | config=args, 122 | mode=args.debug, 123 | ) 124 | Exp = Exp_All_Task_SSL 125 | 126 | if args.is_training: 127 | for ii in range(args.itr): 128 | # setting record of experiments 129 | setting = '{}_{}_{}_{}_ft{}_dm{}_el{}_{}_{}'.format( 130 | args.task_name, 131 | args.model_id, 132 | args.model, 133 | args.data, 134 | args.features, 135 | args.d_model, 136 | args.e_layers, 137 | args.des, ii) 138 | 139 | exp = Exp(args) # set experiments 140 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) 141 | exp.train(setting) 142 | 143 | torch.cuda.empty_cache() 144 | -------------------------------------------------------------------------------- /scripts/few_shot_anomaly_detection/UniTS_finetune_few_shot_anomaly_detection.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=anomaly_detection 4 | exp_name=finetune_few_shot_anomaly_detection_pct05 5 | 6 | # Path to the SSL pre-trained checkpoint 7 | # get ssl pretrained checkpoint: scripts/pretrain_prompt_learning/UniTS_pretrain_x32.sh 8 | ckpt_path=newcheckpoints/units_x32_pretrain_checkpoint.pth 9 | random_port=$((RANDOM % 9000 + 1000)) 10 | 11 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \ 12 | --fix_seed 2021 \ 13 | --is_training 1 \ 14 | --subsample_pct 0.05 \ 15 | --model_id $exp_name \ 16 | --pretrained_weight $ckpt_path \ 17 | --model $model_name \ 18 | --prompt_num 10 \ 19 | --patch_len 16 \ 20 | --stride 16 \ 21 | --e_layers 3 \ 22 | --d_model 32 \ 23 | --des 'Exp' \ 24 | --itr 1 \ 25 | --lradj finetune_anl \ 26 | --learning_rate 5e-4 \ 27 | --weight_decay 1e-3 \ 28 | 
--train_epochs 10 \ 29 | --batch_size 32 \ 30 | --acc_it 32 \ 31 | --dropout 0 \ 32 | --debug $wandb_mode \ 33 | --project_name $project_name \ 34 | --clip_grad 100 \ 35 | --task_data_config_path data_provider/anomaly_detection.yaml -------------------------------------------------------------------------------- /scripts/few_shot_anomaly_detection/UniTS_prompt_tuning_few_shot_anomaly_detection.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=anomaly_detection 4 | exp_name=prompt_tuning_few_shot_anomaly_detection_pct05 5 | 6 | 7 | # Path to the SSL pre-trained checkpoint 8 | # get ssl pretrained checkpoint: scripts/pretrain_prompt_learning/UniTS_pretrain_x32.sh 9 | ckpt_path=newcheckpoints/units_x32_pretrain_checkpoint.pth 10 | random_port=$((RANDOM % 9000 + 1000)) 11 | 12 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \ 13 | --fix_seed 2021 \ 14 | --is_training 1 \ 15 | --subsample_pct 0.05 \ 16 | --model_id $exp_name \ 17 | --pretrained_weight $ckpt_path \ 18 | --model $model_name \ 19 | --prompt_num 10 \ 20 | --patch_len 16 \ 21 | --stride 16 \ 22 | --e_layers 3 \ 23 | --d_model 32 \ 24 | --des 'Exp' \ 25 | --itr 1 \ 26 | --lradj prompt_tuning \ 27 | --learning_rate 5e-5 \ 28 | --weight_decay 1e-2 \ 29 | --train_epochs 0 \ 30 | --prompt_tune_epoch 10 \ 31 | --batch_size 32 \ 32 | --acc_it 32 \ 33 | --dropout 0.0 \ 34 | --debug $wandb_mode \ 35 | --project_name $project_name \ 36 | --clip_grad 100 \ 37 | --task_data_config_path data_provider/anomaly_detection.yaml -------------------------------------------------------------------------------- /scripts/few_shot_imputation/UniTS_finetune_few_shot_imputation_mask025.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_imputation 4 | exp_name=fewshot_imputation_finetune_mask025 5 | 
ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth 6 | 7 | random_port=$((RANDOM % 9000 + 1000)) 8 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \ 9 | --is_training 1 \ 10 | --fix_seed 2021 \ 11 | --model_id $exp_name \ 12 | --subsample_pct 0.1 \ 13 | --mask_rate 0.25 \ 14 | --model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | --d_model 64 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --lradj finetune_imp \ 23 | --learning_rate 3e-4 \ 24 | --weight_decay 5e-6 \ 25 | --train_epochs 20 \ 26 | --batch_size 32 \ 27 | --acc_it 32 \ 28 | --clip_grad 1.0 \ 29 | --debug $wandb_mode \ 30 | --project_name $project_name \ 31 | --pretrained_weight $ckpt_path \ 32 | --task_data_config_path data_provider/imputation.yaml \ -------------------------------------------------------------------------------- /scripts/few_shot_imputation/UniTS_finetune_few_shot_imputation_mask050.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_imputation 4 | exp_name=fewshot_imputation_finetune_mask050 5 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth 6 | 7 | random_port=$((RANDOM % 9000 + 1000)) 8 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \ 9 | --is_training 1 \ 10 | --fix_seed 2021 \ 11 | --model_id $exp_name \ 12 | --subsample_pct 0.1 \ 13 | --mask_rate 0.50 \ 14 | --model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | --d_model 64 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --lradj finetune_imp \ 23 | --learning_rate 3e-4 \ 24 | --weight_decay 5e-6 \ 25 | --train_epochs 20 \ 26 | --batch_size 32 \ 27 | --acc_it 32 \ 28 | --clip_grad 1.0 \ 29 | --debug $wandb_mode \ 30 | --project_name $project_name \ 31 | --pretrained_weight $ckpt_path \ 32 | --task_data_config_path data_provider/imputation.yaml \ 
-------------------------------------------------------------------------------- /scripts/few_shot_imputation/UniTS_prompt_tuning_few_shot_imputation_mask025.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_imputation 4 | exp_name=fewshot_imputation_prompt_tuning_mask025 5 | 6 | # Path to the SSL pre-trained checkpoint 7 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth 8 | 9 | random_port=$((RANDOM % 9000 + 1000)) 10 | ckpt_path=newcheckpoints/units_prompt_tuning_few_shot_imputation_mask025.pth 11 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \ 12 | --is_training 1 \ 13 | --fix_seed 2021 \ 14 | --model_id $exp_name \ 15 | --subsample_pct 0.1 \ 16 | --mask_rate 0.25 \ 17 | --model $model_name \ 18 | --prompt_num 10 \ 19 | --patch_len 16 \ 20 | --stride 16 \ 21 | --e_layers 3 \ 22 | --d_model 128 \ 23 | --des 'Exp' \ 24 | --itr 1 \ 25 | --prompt_tune_epoch 20 \ 26 | --train_epochs 0 \ 27 | --lradj prompt_tuning \ 28 | --learning_rate 5e-3 \ 29 | --weight_decay 0 \ 30 | --batch_size 32 \ 31 | --acc_it 32 \ 32 | --clip_grad 1.0 \ 33 | --dropout 0 \ 34 | --debug $wandb_mode \ 35 | --project_name $project_name \ 36 | --pretrained_weight $ckpt_path \ 37 | --task_data_config_path data_provider/imputation.yaml \ -------------------------------------------------------------------------------- /scripts/few_shot_imputation/UniTS_prompt_tuning_few_shot_imputation_mask050.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_imputation 4 | exp_name=fewshot_imputation_prompt_tuning_mask050 5 | 6 | # Path to the SSL pre-trained checkpoint 7 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth 8 | 9 | random_port=$((RANDOM % 9000 + 1000)) 10 | 11 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \ 12 | 
--is_training 1 \ 13 | --fix_seed 2021 \ 14 | --model_id $exp_name \ 15 | --subsample_pct 0.1 \ 16 | --mask_rate 0.50 \ 17 | --model $model_name \ 18 | --prompt_num 10 \ 19 | --patch_len 16 \ 20 | --stride 16 \ 21 | --e_layers 3 \ 22 | --d_model 128 \ 23 | --des 'Exp' \ 24 | --itr 1 \ 25 | --prompt_tune_epoch 20 \ 26 | --train_epochs 0 \ 27 | --lradj prompt_tuning \ 28 | --learning_rate 5e-3 \ 29 | --weight_decay 0 \ 30 | --batch_size 32 \ 31 | --acc_it 32 \ 32 | --clip_grad 1.0 \ 33 | --dropout 0 \ 34 | --debug $wandb_mode \ 35 | --project_name $project_name \ 36 | --pretrained_weight $ckpt_path \ 37 | --task_data_config_path data_provider/imputation.yaml \ -------------------------------------------------------------------------------- /scripts/few_shot_newdata/UniTS_finetune_few_shot_newdata_pct05.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_newdata 4 | exp_name=fewshot_newdata_finetune_pct05 5 | random_port=$((RANDOM % 9000 + 1000)) 6 | # Path to the supervised checkpoint 7 | # get supervised checkpoint: scripts/supervised_learning/UniTS_supervised.sh 8 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth 9 | torchrun --nnodes 1 --master_port $random_port run.py \ 10 | --is_training 1 \ 11 | --fix_seed 2021 \ 12 | --model_id $exp_name \ 13 | --subsample_pct 0.05 \ 14 | --model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | --d_model 64 \ 20 | --des 'Exp' \ 21 | --train_epochs 5 \ 22 | --learning_rate 1e-4 \ 23 | --weight_decay 1e-5 \ 24 | --lradj supervised \ 25 | --dropout 0.1 \ 26 | --acc_it 8 \ 27 | --clip_grad 100 \ 28 | --debug $wandb_mode \ 29 | --project_name $project_name \ 30 | --pretrained_weight $ckpt_path \ 31 | --task_data_config_path data_provider/fewshot_new_task.yaml 32 | 33 | -------------------------------------------------------------------------------- 
/scripts/few_shot_newdata/UniTS_finetune_few_shot_newdata_pct15.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_newdata 4 | exp_name=fewshot_newdata_finetune_pct15 5 | random_port=$((RANDOM % 9000 + 1000)) 6 | # Path to the supervised checkpoint 7 | # get supervised checkpoint: scripts/supervised_learning/UniTS_supervised.sh 8 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth 9 | torchrun --nnodes 1 --master_port $random_port run.py \ 10 | --is_training 1 \ 11 | --fix_seed 2021 \ 12 | --model_id $exp_name \ 13 | --subsample_pct 0.15 \ 14 | --model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | --d_model 64 \ 20 | --des 'Exp' \ 21 | --train_epochs 5 \ 22 | --learning_rate 1e-4 \ 23 | --weight_decay 1e-5 \ 24 | --lradj supervised \ 25 | --dropout 0.1 \ 26 | --acc_it 8 \ 27 | --clip_grad 100 \ 28 | --debug $wandb_mode \ 29 | --project_name $project_name \ 30 | --pretrained_weight $ckpt_path \ 31 | --task_data_config_path data_provider/fewshot_new_task.yaml 32 | 33 | -------------------------------------------------------------------------------- /scripts/few_shot_newdata/UniTS_finetune_few_shot_newdata_pct20.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_newdata 4 | exp_name=fewshot_newdata_finetune_pct20 5 | random_port=$((RANDOM % 9000 + 1000)) 6 | # Path to the supervised checkpoint 7 | # get supervised checkpoint: scripts/supervised_learning/UniTS_supervised.sh 8 | ckpt_path=newcheckpoints/units_x64_supervised_checkpoint.pth 9 | torchrun --nnodes 1 --master_port $random_port run.py \ 10 | --is_training 1 \ 11 | --fix_seed 2021 \ 12 | --model_id $exp_name \ 13 | --subsample_pct 0.20 \ 14 | --model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | 
--d_model 64 \ 20 | --des 'Exp' \ 21 | --train_epochs 5 \ 22 | --learning_rate 1e-4 \ 23 | --weight_decay 1e-5 \ 24 | --lradj supervised \ 25 | --dropout 0.1 \ 26 | --acc_it 8 \ 27 | --clip_grad 100 \ 28 | --debug $wandb_mode \ 29 | --project_name $project_name \ 30 | --pretrained_weight $ckpt_path \ 31 | --task_data_config_path data_provider/fewshot_new_task.yaml 32 | 33 | -------------------------------------------------------------------------------- /scripts/few_shot_newdata/UniTS_prompt_tuning_few_shot_newdata_pct05.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_newdata 4 | exp_name=fewshot_newdata_prompt_tuning_pct05 5 | random_port=$((RANDOM % 9000 + 1000)) 6 | 7 | # Path to the SSL pre-trained checkpoint 8 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth 9 | 10 | torchrun --nnodes 1 --master_port $random_port run.py \ 11 | --is_training 1 \ 12 | --fix_seed 2021 \ 13 | --model_id $exp_name \ 14 | --subsample_pct 0.05 \ 15 | --model $model_name \ 16 | --prompt_num 10 \ 17 | --patch_len 16 \ 18 | --stride 16 \ 19 | --e_layers 3 \ 20 | --d_model 128 \ 21 | --des 'Exp' \ 22 | --prompt_tune_epoch 10 \ 23 | --train_epochs 0 \ 24 | --lradj prompt_tuning \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 1e-4 \ 27 | --dropout 0 \ 28 | --acc_it 8 \ 29 | --clip_grad 100 \ 30 | --debug $wandb_mode \ 31 | --project_name $project_name \ 32 | --pretrained_weight $ckpt_path \ 33 | --task_data_config_path data_provider/fewshot_new_task.yaml -------------------------------------------------------------------------------- /scripts/few_shot_newdata/UniTS_prompt_tuning_few_shot_newdata_pct15.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_newdata 4 | exp_name=fewshot_newdata_prompt_tuning_pct15 5 | random_port=$((RANDOM % 9000 + 1000)) 6 | 7 | # Path to the 
SSL pre-trained checkpoint 8 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth 9 | 10 | torchrun --nnodes 1 --master_port $random_port run.py \ 11 | --is_training 1 \ 12 | --fix_seed 2021 \ 13 | --model_id $exp_name \ 14 | --subsample_pct 0.15 \ 15 | --model $model_name \ 16 | --prompt_num 10 \ 17 | --patch_len 16 \ 18 | --stride 16 \ 19 | --e_layers 3 \ 20 | --d_model 128 \ 21 | --des 'Exp' \ 22 | --prompt_tune_epoch 10 \ 23 | --train_epochs 0 \ 24 | --lradj prompt_tuning \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 1e-4 \ 27 | --dropout 0 \ 28 | --acc_it 8 \ 29 | --clip_grad 100 \ 30 | --debug $wandb_mode \ 31 | --project_name $project_name \ 32 | --pretrained_weight $ckpt_path \ 33 | --task_data_config_path data_provider/fewshot_new_task.yaml -------------------------------------------------------------------------------- /scripts/few_shot_newdata/UniTS_prompt_tuning_few_shot_newdata_pct20.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | wandb_mode=online 3 | project_name=fewshot_newdata 4 | exp_name=fewshot_newdata_prompt_tuning_pct20 5 | random_port=$((RANDOM % 9000 + 1000)) 6 | 7 | # Path to the SSL pre-trained checkpoint 8 | ckpt_path=newcheckpoints/units_x128_pretrain_checkpoint.pth 9 | 10 | torchrun --nnodes 1 --master_port $random_port run.py \ 11 | --is_training 1 \ 12 | --fix_seed 2021 \ 13 | --model_id $exp_name \ 14 | --subsample_pct 0.20 \ 15 | --model $model_name \ 16 | --prompt_num 10 \ 17 | --patch_len 16 \ 18 | --stride 16 \ 19 | --e_layers 3 \ 20 | --d_model 128 \ 21 | --des 'Exp' \ 22 | --prompt_tune_epoch 10 \ 23 | --train_epochs 0 \ 24 | --lradj prompt_tuning \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 1e-4 \ 27 | --dropout 0 \ 28 | --acc_it 8 \ 29 | --clip_grad 100 \ 30 | --debug $wandb_mode \ 31 | --project_name $project_name \ 32 | --pretrained_weight $ckpt_path \ 33 | --task_data_config_path data_provider/fewshot_new_task.yaml 
-------------------------------------------------------------------------------- /scripts/pretrain_prompt_learning/UniTS_pretrain_x128.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | exp_name=UniTS_pretrain_x128 3 | wandb_mode=online 4 | ptune_name=prompt_tuning 5 | 6 | d_model=128 7 | 8 | random_port=$((RANDOM % 9000 + 1000)) 9 | 10 | # Pretrain 11 | torchrun --nnodes 1 --nproc-per-node 2 --master_port $random_port run_pretrain.py \ 12 | --is_training 1 \ 13 | --model_id $exp_name \ 14 | --model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | --d_model $d_model \ 20 | --des 'Exp' \ 21 | --acc_it 128 \ 22 | --batch_size 32 \ 23 | --learning_rate 5e-5 \ 24 | --min_lr 1e-4 \ 25 | --weight_decay 5e-6 \ 26 | --train_epochs 10 \ 27 | --warmup_epochs 0 \ 28 | --min_keep_ratio 0.5 \ 29 | --right_prob 0.5 \ 30 | --min_mask_ratio 0.7 \ 31 | --max_mask_ratio 0.8 \ 32 | --debug $wandb_mode \ 33 | --task_data_config_path data_provider/multi_task_pretrain.yaml 34 | 35 | # Prompt tuning 36 | torchrun --nnodes 1 --master_port $random_port run.py \ 37 | --is_training 1 \ 38 | --model_id $exp_name \ 39 | --model $model_name \ 40 | --lradj prompt_tuning \ 41 | --prompt_num 10 \ 42 | --patch_len 16 \ 43 | --stride 16 \ 44 | --e_layers 3 \ 45 | --d_model $d_model \ 46 | --des 'Exp' \ 47 | --itr 1 \ 48 | --learning_rate 3e-3 \ 49 | --weight_decay 0 \ 50 | --prompt_tune_epoch 2 \ 51 | --train_epochs 0 \ 52 | --acc_it 32 \ 53 | --debug $wandb_mode \ 54 | --project_name $ptune_name \ 55 | --clip_grad 100 \ 56 | --pretrained_weight auto \ 57 | --task_data_config_path data_provider/multi_task.yaml 58 | -------------------------------------------------------------------------------- /scripts/pretrain_prompt_learning/UniTS_pretrain_x32.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | exp_name=UniTS_pretrain_x32 3 | 
wandb_mode=online 4 | ptune_name=prompt_tuning 5 | 6 | d_model=32 7 | 8 | random_port=$((RANDOM % 9000 + 1000)) 9 | 10 | # Pretrain 11 | torchrun --nnodes 1 --nproc-per-node 1 --master_port $random_port run_pretrain.py \ 12 | --is_training 1 \ 13 | --model_id $exp_name \ 14 | --model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | --d_model $d_model \ 20 | --des 'Exp' \ 21 | --acc_it 128 \ 22 | --batch_size 32 \ 23 | --learning_rate 5e-5 \ 24 | --min_lr 1e-4 \ 25 | --weight_decay 5e-6 \ 26 | --train_epochs 10 \ 27 | --warmup_epochs 0 \ 28 | --min_keep_ratio 0.5 \ 29 | --right_prob 0.5 \ 30 | --min_mask_ratio 0.7 \ 31 | --max_mask_ratio 0.8 \ 32 | --debug $wandb_mode \ 33 | --task_data_config_path data_provider/multi_task_pretrain.yaml 34 | 35 | # Prompt tuning 36 | torchrun --nnodes 1 --master_port $random_port run.py \ 37 | --is_training 1 \ 38 | --model_id $exp_name \ 39 | --model $model_name \ 40 | --lradj prompt_tuning \ 41 | --prompt_num 10 \ 42 | --patch_len 16 \ 43 | --stride 16 \ 44 | --e_layers 3 \ 45 | --d_model $d_model \ 46 | --des 'Exp' \ 47 | --itr 1 \ 48 | --learning_rate 1e-3 \ 49 | --weight_decay 0 \ 50 | --prompt_tune_epoch 2 \ 51 | --train_epochs 0 \ 52 | --acc_it 32 \ 53 | --debug $wandb_mode \ 54 | --project_name $ptune_name \ 55 | --clip_grad 100 \ 56 | --pretrained_weight auto \ 57 | --task_data_config_path data_provider/multi_task.yaml 58 | -------------------------------------------------------------------------------- /scripts/pretrain_prompt_learning/UniTS_pretrain_x64.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | exp_name=UniTS_pretrain_x64 3 | wandb_mode=online 4 | ptune_name=prompt_tuning 5 | 6 | d_model=64 7 | 8 | random_port=$((RANDOM % 9000 + 1000)) 9 | 10 | # Pretrain 11 | torchrun --nnodes 1 --nproc-per-node 2 --master_port $random_port run_pretrain.py \ 12 | --is_training 1 \ 13 | --model_id $exp_name \ 14 | 
--model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | --d_model $d_model \ 20 | --des 'Exp' \ 21 | --acc_it 128 \ 22 | --batch_size 32 \ 23 | --learning_rate 5e-5 \ 24 | --min_lr 1e-4 \ 25 | --weight_decay 5e-6 \ 26 | --train_epochs 10 \ 27 | --warmup_epochs 0 \ 28 | --min_keep_ratio 0.5 \ 29 | --right_prob 0.5 \ 30 | --min_mask_ratio 0.7 \ 31 | --max_mask_ratio 0.8 \ 32 | --debug $wandb_mode \ 33 | --task_data_config_path data_provider/multi_task_pretrain.yaml 34 | 35 | # Prompt tuning 36 | torchrun --nnodes 1 --master_port $random_port run.py \ 37 | --is_training 1 \ 38 | --model_id $exp_name \ 39 | --model $model_name \ 40 | --lradj prompt_tuning \ 41 | --prompt_num 10 \ 42 | --patch_len 16 \ 43 | --stride 16 \ 44 | --e_layers 3 \ 45 | --d_model $d_model \ 46 | --des 'Exp' \ 47 | --itr 1 \ 48 | --learning_rate 3e-3 \ 49 | --weight_decay 0 \ 50 | --prompt_tune_epoch 2 \ 51 | --train_epochs 0 \ 52 | --acc_it 32 \ 53 | --debug $wandb_mode \ 54 | --project_name $ptune_name \ 55 | --clip_grad 100 \ 56 | --pretrained_weight auto \ 57 | --task_data_config_path data_provider/multi_task.yaml -------------------------------------------------------------------------------- /scripts/supervised_learning/UniTS_supervised.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS 2 | exp_name=UniTS_supervised_x64 3 | wandb_mode=online 4 | project_name=supervised_learning 5 | 6 | random_port=$((RANDOM % 9000 + 1000)) 7 | 8 | # Supervised learning 9 | torchrun --nnodes 1 --nproc-per-node=1 --master_port $random_port run.py \ 10 | --is_training 1 \ 11 | --model_id $exp_name \ 12 | --model $model_name \ 13 | --lradj supervised \ 14 | --prompt_num 10 \ 15 | --patch_len 16 \ 16 | --stride 16 \ 17 | --e_layers 3 \ 18 | --d_model 64 \ 19 | --des 'Exp' \ 20 | --learning_rate 1e-4 \ 21 | --weight_decay 5e-6 \ 22 | --train_epochs 5 \ 23 | --batch_size 32 \ 24 | --acc_it 32 \ 25 | 
--debug $wandb_mode \ 26 | --project_name $project_name \ 27 | --clip_grad 100 \ 28 | --task_data_config_path data_provider/multi_task.yaml -------------------------------------------------------------------------------- /scripts/zero_shot/UniTS_forecast_new_length_unify.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | model_name=UniTS 4 | exp_name=UniTS_pretrain_x128 5 | wandb_mode=disabled 6 | ptune_name=prompt_tuning 7 | 8 | d_model=128 9 | 10 | random_port=$((RANDOM % 9000 + 1000)) 11 | 12 | # Get the pretrained model 13 | # scripts/pretrain_prompt_learning/UniTS_pretrain_x128.sh 14 | ckpt_path=pretrain_ckpt.pth 15 | 16 | 17 | offset=384 18 | torchrun --nnodes 1 --master_port $random_port run.py \ 19 | --is_training 0 \ 20 | --zero_shot_forecasting_new_length unify \ 21 | --max_offset 384 \ 22 | --offset $offset \ 23 | --model_id $exp_name \ 24 | --model $model_name \ 25 | --lradj prompt_tuning \ 26 | --prompt_num 10 \ 27 | --patch_len 16 \ 28 | --stride 16 \ 29 | --e_layers 3 \ 30 | --d_model $d_model \ 31 | --des 'Exp' \ 32 | --itr 1 \ 33 | --debug $wandb_mode \ 34 | --project_name $ptune_name \ 35 | --pretrained_weight $ckpt_path \ 36 | --task_data_config_path data_provider/multitask_zero_shot_new_length.yaml -------------------------------------------------------------------------------- /scripts/zero_shot/UniTS_zeroshot_newdata.sh: -------------------------------------------------------------------------------- 1 | model_name=UniTS_zeroshot 2 | exp_name=UniTS_zeroshot_pretrain_x64 3 | wandb_mode=online 4 | ptune_name=zeroshot_newdata 5 | 6 | d_model=64 7 | 8 | random_port=$((RANDOM % 9000 + 1000)) 9 | 10 | # Pretrain of zero-shot version of UniTS 11 | torchrun --nnodes 1 --nproc-per-node 2 --master_port $random_port run_pretrain.py \ 12 | --is_training 1 \ 13 | --model_id $exp_name \ 14 | --model $model_name \ 15 | --prompt_num 10 \ 16 | --patch_len 16 \ 17 | --stride 16 \ 18 | --e_layers 3 \ 19 | --d_model 
$d_model \ 20 | --des 'Exp' \ 21 | --acc_it 128 \ 22 | --batch_size 32 \ 23 | --learning_rate 5e-5 \ 24 | --min_lr 1e-4 \ 25 | --weight_decay 5e-6 \ 26 | --train_epochs 10 \ 27 | --warmup_epochs 0 \ 28 | --min_keep_ratio 0.5 \ 29 | --right_prob 0.5 \ 30 | --min_mask_ratio 0.7 \ 31 | --max_mask_ratio 0.8 \ 32 | --debug $wandb_mode \ 33 | --task_data_config_path data_provider/multi_task_pretrain.yaml 34 | 35 | # Zero-shot test on new forecasting datasets 36 | # Note: The inference in this code tests all samples of the dataset, 37 | # which differs from the original paper, which tests only 1 sample for each dataset. 38 | torchrun --nnodes 1 --master_port $random_port run.py \ 39 | --is_training 0 \ 40 | --model_id $exp_name \ 41 | --model $model_name \ 42 | --prompt_num 10 \ 43 | --patch_len 16 \ 44 | --stride 16 \ 45 | --e_layers 3 \ 46 | --d_model $d_model \ 47 | --des 'Exp' \ 48 | --debug $wandb_mode \ 49 | --project_name $ptune_name \ 50 | --pretrained_weight auto \ 51 | --task_data_config_path data_provider/zeroshot_task.yaml -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mims-harvard/UniTS/0e0281482864017cac8832b2651906ff5375a34e/utils/__init__.py -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BalancedDataLoaderIterator: 5 | def __init__(self, dataloaders): 6 | self.dataloaders = dataloaders 7 | 8 | self.num_dataloaders = len(dataloaders) 9 | 10 | max_length = max(len(dataloader) for dataloader in dataloaders) 11 | 12 | length_list = [len(dataloader) for dataloader in dataloaders] 13 | print("data loader length:", length_list) 14 | print("max dataloader length:", max_length, 15 | "epoch iteration:", max_length * 
self.num_dataloaders) 16 | self.total_length = max_length * self.num_dataloaders 17 | self.current_iteration = 0 18 | self.probabilities = torch.ones( 19 | self.num_dataloaders, dtype=torch.float) / self.num_dataloaders 20 | 21 | def __iter__(self): 22 | self.iterators = [iter(dataloader) for dataloader in self.dataloaders] 23 | self.current_iteration = 0 24 | return self 25 | 26 | def __next__(self): 27 | if self.current_iteration >= self.total_length: 28 | raise StopIteration 29 | 30 | chosen_index = torch.multinomial(self.probabilities, 1).item() 31 | try: 32 | sample = next(self.iterators[chosen_index]) 33 | except StopIteration: 34 | self.iterators[chosen_index] = iter(self.dataloaders[chosen_index]) 35 | sample = next(self.iterators[chosen_index]) 36 | 37 | self.current_iteration += 1 38 | return sample, chosen_index 39 | 40 | def __len__(self): 41 | return self.total_length 42 | 43 | def generate_fake_samples_for_batch(self, dataloader_id, batch_size): 44 | if dataloader_id >= len(self.dataloaders) or dataloader_id < 0: 45 | raise ValueError("Invalid dataloader ID") 46 | 47 | dataloader = self.dataloaders[dataloader_id] 48 | iterator = iter(dataloader) 49 | 50 | try: 51 | sample_batch = next(iterator) 52 | fake_samples = [] 53 | 54 | for sample in sample_batch: 55 | if isinstance(sample, torch.Tensor): 56 | fake_sample = torch.zeros( 57 | [batch_size] + list(sample.shape)[1:]) 58 | fake_samples.append(fake_sample) 59 | else: 60 | pass 61 | 62 | return fake_samples, dataloader_id 63 | except StopIteration: 64 | return None 65 | -------------------------------------------------------------------------------- /utils/ddp.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | import torch 3 | 4 | 5 | def is_dist_avail_and_initialized(): 6 | if not dist.is_available(): 7 | return False 8 | if not dist.is_initialized(): 9 | return False 10 | return True 11 | 12 | 13 | def get_world_size(): 14 | 
def gather_tensors_from_all_gpus(tensor_list, device_id, to_numpy=True):
    """
    Gather tensors from all processes in a DDP setup onto each process.

    When torch.distributed is not available or not initialized (plain
    single-process run) there is nothing to gather, so the local tensors are
    returned as the complete result instead of failing.

    Args:
        tensor_list (list of torch.Tensor): List of tensors on the local process.
        device_id: Device the local tensors are moved to before gathering.
        to_numpy (bool): If True, convert every returned tensor to a CPU numpy
            array; otherwise return torch tensors.

    Returns:
        list: Flat list holding, for each input tensor in order, one entry per
        rank (world_size entries per input tensor).
    """
    local_tensors = [tensor.to(device_id).contiguous()
                     for tensor in tensor_list]

    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        gathered = []
        for tensor in local_tensors:
            # Receive buffers are allocated with the local tensor's shape and
            # dtype, one buffer per rank.
            per_rank = [torch.empty_like(tensor) for _ in range(world_size)]
            dist.all_gather(per_rank, tensor)
            gathered.extend(per_rank)
    else:
        gathered = local_tensors

    if to_numpy:
        return [tensor.cpu().numpy() for tensor in gathered]
    return gathered
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=None, layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Trainable parameters are bucketed by (layer id, decay / no_decay). A
    bucket for layer id ``i`` gets ``lr_scale = layer_decay ** (num_layers - i)``
    where ``num_layers = len(model.blocks) + 2``.

    Args:
        model: Module exposing ``blocks`` and ``named_parameters()``.
        weight_decay: Weight decay applied to the "decay" buckets.
        no_weight_decay_list: Parameter names that never receive weight decay
            (in addition to all 1-D parameters). ``None`` means no extra names.
        layer_decay: Per-layer multiplicative learning-rate decay factor.

    Returns:
        list of dict: One dict per bucket with keys ``lr_scale``,
        ``weight_decay`` and ``params`` (the parameter tensors).
    """
    # Avoid a shared mutable default; an empty tuple supports the same `in` test.
    if no_weight_decay_list is None:
        no_weight_decay_list = ()

    param_groups = {}

    num_layers = len(model.blocks) + 2

    layer_scales = [layer_decay ** (num_layers - i)
                    for i in range(num_layers + 1)]

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay, this_decay = "no_decay", 0.
        else:
            g_decay, this_decay = "decay", weight_decay

        layer_id = get_layer_id_for_model(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_groups:
            this_scale = layer_scales[layer_id]
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            # Logged once per bucket, using the first parameter that lands in it.
            print("name: %s, layer_id: %d, group_name: %s, lr_scale: %f, weight_decay: %f" % (
                n, layer_id, group_name, this_scale, this_decay))

        param_groups[group_name]["params"].append(p)

    return list(param_groups.values())
def divide_no_nan(a, b):
    """
    a/b where the resulting NaN, +Inf or -Inf entries are replaced by 0.

    The replacement is done in place on the freshly computed quotient, so
    neither ``a`` nor ``b`` is modified.
    """
    result = a / b
    # NaN is the only value that differs from itself.
    result[result != result] = .0
    # Compare the magnitude so that -inf (e.g. -1 / 0) is zeroed as well.
    result[abs(result) == np.inf] = .0
    return result
class smape_loss(nn.Module):
    """Symmetric MAPE loss module; `insample` and `freq` are accepted but unused."""

    def __init__(self):
        super(smape_loss, self).__init__()

    def forward(self, insample: t.Tensor, freq: int,
                forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
        """
        sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993)

        :param forecast: Forecast values. Shape: batch, time
        :param target: Target values. Shape: batch, time
        :param mask: 0/1 mask. Shape: batch, time
        :return: Loss value
        """
        abs_error = t.abs(forecast - target)
        # The denominator is built from `.data`, so gradients only flow
        # through the numerator.
        scale = t.abs(forecast.data) + t.abs(target.data)
        ratio = divide_no_nan(abs_error, scale)
        return 200 * t.mean(ratio * mask)
class UnifiedMaskRecLoss(nn.Module):
    """Masked reconstruction loss combining a mask branch and an optional cls branch."""

    def __init__(self):
        super().__init__()

    def forward_mim_loss(self, target, pred, pad_mask):
        """Squared error averaged over the last dim, then over positions kept by pad_mask."""
        keep = pad_mask.bool()
        # Mean over the last dimension gives one loss value per position.
        per_position = ((pred - target) ** 2).mean(dim=-1)
        return (per_position * keep).sum() / keep.sum()

    def forward(self, outputs, target, pad_mask):
        """
        Compute the loss dict for a 3-tuple of outputs.

        :param outputs: (cls prediction or None, mask prediction, ignored third item)
        :param target: Reconstruction target compared against both predictions
        :param pad_mask: Positions with a truthy value contribute to the loss
        :return: dict with keys cls_loss, mask_loss and loss (their sum)
        """
        student_cls, student_fore, _ = outputs

        mask_loss = self.forward_mim_loss(target, student_fore, pad_mask)

        if student_cls is None:
            # Zero that keeps the dtype/device (and graph) of mask_loss.
            cls_loss = 0.0 * mask_loss
        else:
            cls_loss = self.forward_mim_loss(target, student_cls, pad_mask)

        return dict(cls_loss=cls_loss,
                    mask_loss=mask_loss,
                    loss=mask_loss + cls_loss)
def CORR(pred, true):
    """
    Pearson correlation between pred and true, computed along axis 0 and
    then averaged over the last remaining axis.

    The denominator is sqrt(sum(dt**2) * sum(dp**2)) with dt/dp the centred
    series; multiplying the squared deviations element-wise before summing
    (sum(dt**2 * dp**2)) is not a correlation and can exceed 1.
    """
    true_centered = true - true.mean(0)
    pred_centered = pred - pred.mean(0)
    u = (true_centered * pred_centered).sum(0)
    d = np.sqrt((true_centered ** 2).sum(0) * (pred_centered ** 2).sum(0))
    return (u / d).mean(-1)
def metric(pred, true):
    """Evaluate pred against true and return (mae, mse, rmse, mape, mspe)."""
    # Tuple elements are evaluated left to right, matching the reporting order.
    return (
        MAE(pred, true),
        MSE(pred, true),
        RMSE(pred, true),
        MAPE(pred, true),
        MSPE(pred, true),
    )
class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Scale by the largest second value (59) first, then shift so the
        # feature is centred on zero.
        scaled = index.second / 59.0
        return scaled - 0.5
def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Raises
    ------
    RuntimeError
        If the offset parsed from ``freq_str`` matches none of the supported offset types.
    """

    # Checked in insertion order; the first offset type the parsed offset is
    # an instance of decides which features are returned.
    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    # Quarterly is listed because offsets.QuarterEnd is handled above.
    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        Q   - quarterly
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)
def adjust_learning_rate(optimizer, epoch, base_lr, args):
    """
    Set the learning rate of every optimizer param group for the given epoch.

    Schedules, selected by ``args.lradj``:
      * 'prompt_tuning': ``args.learning_rate`` halved each epoch before
        ``args.prompt_tune_epoch``; ``base_lr`` exactly at that epoch; afterwards
        ``args.learning_rate`` halved each epoch, restarting from the epoch
        after ``args.prompt_tune_epoch``.
      * 'supervised': ``base_lr`` up to and including ``args.prompt_tune_epoch``,
        then ``base_lr / 5`` halved for every epoch past it.
      * 'finetune_anl': ``base_lr`` halved each epoch.
    Any other value leaves the optimizer untouched.

    Param groups carrying an ``lr_scale`` entry get ``lr * lr_scale``.

    Args:
        optimizer: Object exposing ``param_groups`` (list of dicts).
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
        base_lr: Base learning rate used by the schedules above.
        args: Namespace with ``lradj``, ``prompt_tune_epoch`` and
            ``learning_rate``.
    """
    assert args.prompt_tune_epoch >= 0, "args.prompt_tune_epoch >=0!"

    # Stays empty for an unrecognised args.lradj so the lookup below is a
    # no-op instead of failing with UnboundLocalError.
    lr_adjust = {}
    if args.lradj == 'prompt_tuning':
        if epoch < args.prompt_tune_epoch:
            lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch) // 1))}
        elif epoch == args.prompt_tune_epoch:
            lr_adjust = {epoch: base_lr}
        else:
            lr_adjust = {epoch: args.learning_rate *
                         (0.5 ** (((epoch-args.prompt_tune_epoch) - 1) // 1))}
    elif args.lradj == 'supervised':
        if epoch <= args.prompt_tune_epoch:
            lr_adjust = {epoch: base_lr}
        else:
            lr_adjust = {epoch: base_lr / 5 *
                         (0.5 ** (((epoch-args.prompt_tune_epoch)) // 1))}
    elif args.lradj == 'finetune_anl':
        k = 1
        lr_adjust = {epoch: base_lr / (2 ** ((epoch) // k))}

    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            if "lr_scale" in param_group:
                param_group["lr"] = lr * param_group["lr_scale"]
            else:
                param_group["lr"] = lr
        print('Epoch {}: Updating learning rate to {}'.format(epoch+1, lr))
def adjustment(gt, pred):
    """
    Point-adjust anomaly predictions in place.

    Whenever a ground-truth anomaly segment (consecutive gt == 1) contains at
    least one predicted point, every point of that segment is marked as
    predicted. Predictions outside ground-truth segments are left untouched.

    Args:
        gt: Indexable 0/1 ground-truth labels.
        pred: Indexable, mutable 0/1 predictions of the same length; modified
            in place.

    Returns:
        The same ``gt`` and ``pred`` objects as a tuple.
    """
    anomaly_state = False
    for i in range(len(gt)):
        if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
            anomaly_state = True
            # Walk back to the start of the segment. The stop value is -1 so
            # that index 0 is visited too when the segment starts there.
            for j in range(i, -1, -1):
                if gt[j] == 0:
                    break
                if pred[j] == 0:
                    pred[j] = 1
            # Walk forward to the end of the segment.
            for j in range(i, len(gt)):
                if gt[j] == 0:
                    break
                if pred[j] == 0:
                    pred[j] = 1
        elif gt[i] == 0:
            anomaly_state = False
        if anomaly_state:
            pred[i] = 1
    return gt, pred
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the total gradient norm over the given parameters.

    A single tensor is treated as a one-element collection; parameters whose
    ``grad`` is None are ignored. With no gradients at all the result is a
    zero tensor. For ``norm_type == inf`` the largest absolute gradient entry
    is returned, otherwise the ``norm_type``-norm of the per-parameter norms.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)

    # All partial results are moved to the device of the first gradient.
    device = grads[0].device
    if norm_type == inf:
        return max(g.detach().abs().max().to(device) for g in grads)

    per_param = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_param), norm_type)