├── figures ├── tsde.jpg ├── tsde.pdf ├── clustering.jpg ├── embed_viz.png ├── tsde_logo.png ├── pred_example.jpg └── forecasting_pred.pdf ├── requirements.txt ├── CITATION.md ├── LICENSE ├── src ├── config │ ├── base_ad.yaml │ ├── base_forecasting.yaml │ ├── base_classification.yaml │ └── base.yaml ├── base │ ├── diffEmbedding.py │ ├── denoisingNetwork.py │ └── mtsEmbedding.py ├── experiments │ ├── train_test_anomaly_detection.py │ ├── train_test_tslib_forecasting.py │ ├── train_test_interpolation.py │ ├── train_test_imputation.py │ ├── train_test_classification.py │ └── train_test_forecasting.py ├── utils │ ├── metrics.py │ ├── download_data.py │ ├── masking_strategies.py │ └── utils.py ├── data_loader │ ├── elec_tslib_dataloader.py │ ├── pm25_dataloader.py │ ├── physio_dataloader.py │ ├── forecasting_dataloader.py │ └── anomaly_detection_dataloader.py └── tsde │ └── main_model.py ├── .gitignore └── README.md /figures/tsde.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/tsde.jpg -------------------------------------------------------------------------------- /figures/tsde.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/tsde.pdf -------------------------------------------------------------------------------- /figures/clustering.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/clustering.jpg -------------------------------------------------------------------------------- /figures/embed_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/embed_viz.png -------------------------------------------------------------------------------- /figures/tsde_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/tsde_logo.png -------------------------------------------------------------------------------- /figures/pred_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/pred_example.jpg -------------------------------------------------------------------------------- /figures/forecasting_pred.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/forecasting_pred.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | pandas 4 | requests 5 | scikit-learn 6 | scipy 7 | tqdm 8 | wget # Note: 'wget' package might be named differently in Conda 9 | --extra-index-url https://download.pytorch.org/whl/cu118 ## if cuda is available if not comment to use cpu 10 | torch==2.2.1+cu118 11 | pyyaml 12 | gdown -------------------------------------------------------------------------------- /CITATION.md: -------------------------------------------------------------------------------- 1 | # Citation 2 | 3 | If you use or refer to this repository in your research, please cite our paper: 4 | 5 | ### BibTeX 6 | ```bash 7 | @inproceedings{senane2024tsde, 8 | 
title={{Self-Supervised Learning of Time Series Representation via Diffusion Process and Imputation-Interpolation-Forecasting Mask}}, 9 | author={Senane, Zineb and Cao, Lele and Buchner, Valentin Leonhard and Tashiro, Yusuke and You, Lei and Herman, Pawel and Nordahl, Mats and Tu, Ruibo and von Ehrenheim, Vilhelm}, 10 | booktitle={to appear In KDD 24: Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining.}, 11 | year={2024} 12 | } 13 | ``` 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 EQT Partners AB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/config/base_ad.yaml: -------------------------------------------------------------------------------- 1 | #type: args 2 | 3 | train: 4 | epochs: 250 ## Pre-raining epochs using IIF masking 5 | batch_size: 32 6 | lr: 1.0e-3 7 | 8 | diffusion: 9 | layers: 4 ## Number of residual layers in the denoising block 10 | channels: 64 ## Number of channels for projections in the denoising block (residual channels) 11 | diffusion_embedding_dim: 128 ## Diffusion step embedding dimension 12 | beta_start: 0.0001 ## minimum noise level in the forward pass 13 | beta_end: 0.5 ## maximum noise level in the forward pass 14 | num_steps: 50 ## Total number of diffusion steps 15 | schedule: "quad" ## Type of noise scheduler 16 | 17 | model: 18 | timeemb: 128 ## Time embedding dimension 19 | featureemb: 16 ## Feature embedding dimension 20 | mix_masking_strategy: "equal_p" ## Mix masking strategy 21 | time_strategy: "hawkes" ## Time embedding type 22 | 23 | embedding: 24 | timeemb: 128 25 | featureemb: 16 26 | num_feat: 51 ## Total number of features in the MTS (K) 27 | num_timestamps: 100 ## Total number of timestamps in the MTS (L) 28 | classes: 2 29 | channels: 16 ## Number of embedding dimension in both temporal and spatial encoders 30 | nheads: 8 ## Number of heads in the temporal and spatial encoders 31 | 32 | finetuning: 33 | epochs: 30 ## Number of finetuning epochs for the downstream task 34 | 35 | -------------------------------------------------------------------------------- /src/config/base_forecasting.yaml: -------------------------------------------------------------------------------- 1 | #type: args 2 | 3 | train: 4 | epochs: 400 ## Pre-raining epochs using IIF masking 5 | batch_size: 8 6 | lr: 1.0e-3 7 | 8 | diffusion: 9 | layers: 4 ## Number of residual layers in the denoising block 10 | channels: 64 ## Number of channels for projections in the denoising block (residual channels) 11 | diffusion_embedding_dim: 128 ## Diffusion step embedding dimension 12 | beta_start: 0.0001 ## minimum noise level in the forward pass 13 | beta_end: 0.5 ## maximum noise level in the forward pass 14 | num_steps: 50 ## Total number of diffusion steps 15 | schedule: "quad" ## Type of noise scheduler 16 | 17 | model: 18 | timeemb: 128 ## Time embedding dimension 19 | featureemb: 16 ## Feature embedding dimension 20 | mix_masking_strategy: "equal_p" ## Mix masking strategy 21 | time_strategy: "hawkes" ## Time embedding type 22 | 23 | embedding: 24 | timeemb: 128 25 | featureemb: 16 26 | num_feat: 35 ## Total number of features in the MTS (K) 27 | num_timestamps: 48 ## Total number of timestamps in the MTS (L) 28 | classes: 1 29 | channels: 16 ## Number of embedding dimension in both temporal and spatial encoders 30 | nheads: 8 ## Number of heads in the temporal and spatial encoders 31 | 32 | finetuning: 33 | epochs: 20 ## Number of finetuning epochs for the downstream task 34 | -------------------------------------------------------------------------------- /src/config/base_classification.yaml: -------------------------------------------------------------------------------- 1 | #type: args 2 | 3 | train: 4 | epochs: 500 ## Pre-raining epochs using IIF masking 5 | batch_size: 4 6 | lr: 1.0e-3 7 | 8 | diffusion: 9 | layers: 4 ## Number of residual layers in the denoising block 10 | channels: 64 ## Number of channels for projections in the denoising block (residual channels) 11 | diffusion_embedding_dim: 128 ## Diffusion step 
embedding dimension 12 | beta_start: 0.0001 ## minimum noise level in the forward pass 13 | beta_end: 0.5 ## maximum noise level in the forward pass 14 | num_steps: 50 ## Total number of diffusion steps 15 | schedule: "quad" ## Type of noise scheduler 16 | 17 | model: 18 | timeemb: 128 ## Time embedding dimension 19 | featureemb: 16 ## Feature embedding dimension 20 | mix_masking_strategy: "equal_p" ## Mix masking strategy 21 | time_strategy: "hawkes" ## Time embedding type 22 | 23 | embedding: 24 | timeemb: 128 25 | featureemb: 16 26 | num_feat: 35 ## Total number of features in the MTS (K) 27 | num_timestamps: 48 ## Total number of timestamps in the MTS (L) 28 | classes: 2 ## Number of classes 29 | channels: 16 ## Number of embedding dimension in both temporal and spatial encoders 30 | nheads: 8 ## Number of heads in the temporal and spatial encoders 31 | 32 | finetuning: 33 | epochs: 50 ## Number of finetuning epochs for the downstream task 34 | 35 | -------------------------------------------------------------------------------- /src/config/base.yaml: -------------------------------------------------------------------------------- 1 | #type: args 2 | 3 | train: 4 | epochs: 1500 ## Pre-raining epochs using IIF masking 5 | batch_size: 16 6 | lr: 1.0e-3 7 | 8 | diffusion: 9 | layers: 4 ## Number of residual layers in the denoising block 10 | channels: 64 ## Number of channels for projections in the denoising block (residual channels) 11 | diffusion_embedding_dim: 128 ## Diffusion step embedding dimension 12 | beta_start: 0.0001 ## minimum noise level in the forward pass 13 | beta_end: 0.5 ## maximum noise level in the forward pass 14 | num_steps: 50 ## Total number of diffusion steps 15 | schedule: "quad" ## Type of noise scheduler 16 | 17 | model: 18 | timeemb: 128 ## Time embedding dimension 19 | featureemb: 16 ## Feature embedding dimension 20 | mix_masking_strategy: "equal_p" ## Mix masking strategy 21 | time_strategy: "hawkes" ## Time embedding type 22 | 23 | embedding: 24 | timeemb: 128 ## Time embedding dimension, needed as parameter for the embedding block, same as the one model 25 | featureemb: 16 ## Feature embedding dimension, needed as parameter for the embedding block, same as the one model 26 | num_feat: 35 ## Total number of features in the MTS (K) 27 | num_timestamps: 48 ## Total number of timestamps in the MTS (L) 28 | classes: 1 29 | channels: 16 ## Number of embedding dimension in both temporal and spatial encoders 30 | nheads: 8 ## Number of heads in the temporal and spatial encoders 31 | 32 | finetuning: 33 | epochs: 100 ## Number of finetuning epochs for the downstream task 34 | -------------------------------------------------------------------------------- /src/base/diffEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class DiffusionEmbedding(nn.Module): 6 | """ 7 | A neural network module for creating embeddings for diffusion steps. This module is designed to encode the diffusion 8 | steps into a continuous space for the reverse pass of the DDPM (denoising block in TSDE). 9 | 10 | Parameters: 11 | - num_steps (int): The number of diffusion steps or time steps to be encoded (T=50). 12 | - embedding_dim (int, optional): The dimensionality of the embedding space. (we set it to 128). 13 | - projection_dim (int, optional): The dimensionality of the projected embedding space. 
If not specified, 14 | it defaults to the same as `embedding_dim`. 15 | 16 | The embedding for a given diffusion step is produced by first generating a sinusoidal embedding of the step, 17 | followed by projecting this embedding through two linear layers with SiLU (Sigmoid Linear Unit) and ReLU 18 | activations, respectively. 19 | """ 20 | 21 | def __init__(self, num_steps, embedding_dim=128, projection_dim=None): 22 | super().__init__() 23 | if projection_dim is None: 24 | projection_dim = embedding_dim 25 | self.register_buffer( 26 | "embedding", 27 | self._build_embedding(num_steps, embedding_dim / 2), 28 | persistent=False, 29 | ) 30 | self.projection1 = nn.Linear(embedding_dim, projection_dim) 31 | self.projection2 = nn.Linear(projection_dim, projection_dim) 32 | 33 | def forward(self, diffusion_step): 34 | """ 35 | Defines the forward pass for projecting the diffusion embedding of a specific diffusion step t. 36 | 37 | Parameters: 38 | - diffusion_step: An integer indicating the diffusion step for which embeddings are generated. 39 | 40 | Returns: 41 | - Tensor: The projected embedding for the given diffusion step. 42 | """ 43 | x = self.embedding[diffusion_step] 44 | x = self.projection1(x) 45 | x = F.silu(x) 46 | x = self.projection2(x) 47 | x = F.relu(x) 48 | return x 49 | 50 | def _build_embedding(self, num_steps, dim=64): 51 | """ 52 | Builds the sinusoidal embedding table for diffusion steps as in CSDI (https://arxiv.org/pdf/2107.03502.pdf). 53 | 54 | Parameters: 55 | - num_steps (int): The number of diffusion steps to encode (T=50). 56 | - dim (int): The dimensionality of the sinusoidal embedding before doubling (due to sin and cos). 57 | 58 | Returns: 59 | - Tensor: A tensor of shape (num_steps, embedding_dim) containing the sinusoidal embeddings of all diffusion steps. 
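        Note: with the default dim=64, the frequencies span 10**0 to 10**4, so the table mirrors Transformer-style sinusoidal positional encodings across the diffusion steps.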
60 | """ 61 | steps = torch.arange(num_steps).unsqueeze(1) # (T,1) 62 | frequencies = 10.0 ** (torch.arange(dim) / (dim - 1) * 4.0).unsqueeze(0) # (1,dim) 63 | table = steps * frequencies # (T,dim) 64 | table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) # (T,dim*2) 65 | return table -------------------------------------------------------------------------------- /src/experiments/train_test_anomaly_detection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | import random 8 | import numpy as np 9 | import torch.nn as nn 10 | 11 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 12 | sys.path.append(root_dir) 13 | 14 | from data_loader.anomaly_detection_dataloader import anomaly_detection_dataloader 15 | from tsde.main_model import TSDE_AD 16 | from utils.utils import train, finetune, evaluate_finetuning, gsutil_cp, set_seed 17 | 18 | torch.backends.cudnn.enabled = False 19 | 20 | parser = argparse.ArgumentParser(description="TSDE-Anomaly Detection") 21 | parser.add_argument("--config", type=str, default="base_ad.yaml") 22 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 23 | parser.add_argument("--seed", type=int, default=1) 24 | parser.add_argument("--disable_finetune", action="store_true") 25 | 26 | 27 | parser.add_argument("--dataset", type=str, default='SMAP') 28 | parser.add_argument("--modelfolder", type=str, default="") 29 | parser.add_argument("--run", type=int, default=1) 30 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal_p or probabilistic_layering)") 31 | parser.add_argument("--anomaly_ratio", type=float, default=1, help="Anomaly ratio") 32 | args = parser.parse_args() 33 | print(args) 34 | 35 | path = "src/config/" + args.config 36 | with open(path, "r") as f: 37 | config = yaml.safe_load(f) 38 | 39 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 40 | 41 | print(json.dumps(config, indent=4)) 42 | 43 | set_seed(args.seed) 44 | 45 | 46 | foldername = "./save/Anomaly_Detection/" + args.dataset + "/run_" + str(args.run) +"/" 47 | model = TSDE_AD(target_dim = config["embedding"]["num_feat"], config = config, device = args.device).to(args.device) 48 | train_loader, valid_loader, test_loader = anomaly_detection_dataloader(dataset_name = args.dataset, batch_size = config["train"]["batch_size"]) 49 | anomaly_ratio = args.anomaly_ratio 50 | 51 | print('model folder:', foldername) 52 | os.makedirs(foldername, exist_ok=True) 53 | 54 | with open(foldername + "config.json", "w") as f: 55 | json.dump(config, f, indent=4) 56 | 57 | if args.modelfolder == "": 58 | loss_path = foldername + "/losses.txt" 59 | with open(loss_path, "a") as file: 60 | file.write("Pretraining"+"\n") 61 | ## Pre-training 62 | train(model, config["train"], train_loader, foldername=foldername, normalize_for_ad=True) 63 | else: 64 | model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)) 65 | 66 | 67 | if not args.disable_finetune: 68 | for param in model.parameters(): 69 | param.requires_grad = False 70 | for param in model.conv.parameters(): 71 | param.requires_grad = True 72 | 73 | for name, param in model.named_parameters(): 74 | print(f"{name}: {param.requires_grad}") 75 | 76 | finetune(model, config["finetuning"], train_loader, criterion = nn.MSELoss(), foldername=foldername, 
task='anomaly_detection', normalize_for_ad=True) 77 | evaluate_finetuning(model, train_loader, test_loader, anomaly_ratio = anomaly_ratio, foldername=foldername, task='anomaly_detection', normalize_for_ad=True) 78 | -------------------------------------------------------------------------------- /src/experiments/train_test_tslib_forecasting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | 8 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 9 | sys.path.append(root_dir) 10 | 11 | # Import necessary modules from your project 12 | from data_loader.elec_tslib_dataloader import get_dataloader_elec 13 | from tsde.main_model import TSDE_Forecasting 14 | from utils.utils import train, evaluate, gsutil_cp, set_seed 15 | 16 | def main(): 17 | # Command line arguments for configuration, could be expanded as needed 18 | parser = argparse.ArgumentParser(description="Run forecasting model") 19 | parser.add_argument('--config', type=str, default='base_forecasting.yaml', help='Path to configuration yaml') 20 | parser.add_argument('--device', type=str, default='cuda:1', help='Device to run the model on') 21 | parser.add_argument('--nsample', type=int, default=100, help='Number of samples') 22 | parser.add_argument('--hist_length', type=int, default=96, help='History window length') 23 | parser.add_argument('--pred_length', type=int, default=192, help='Prediction window length') 24 | parser.add_argument('--run', type=int, default=100200000, help='Run identifier') 25 | parser.add_argument('--linear', action='store_true', help='Linear mode flag') 26 | parser.add_argument('--sample_feat', action='store_true', help='Sample feature flag') 27 | parser.add_argument('--seed', type=int, default=1, help='Seed for random number generation') 28 | parser.add_argument('--load', type=str, default=None, help='Path to pretrained model') 29 | args = parser.parse_args() 30 | 31 | # Set seed for reproducibility 32 | set_seed(args.seed) 33 | 34 | # Load config 35 | path = "src/config/" + args.config 36 | with open(path, "r") as f: 37 | config = yaml.safe_load(f) 38 | 39 | # Setup model folder path 40 | foldername = f"./save/Forecasting/TSLIB_Elec/n_samples_{args.nsample}_run_{args.run}_linear_{args.linear}_sample_feat_{args.sample_feat}/" 41 | os.makedirs(foldername, exist_ok=True) 42 | 43 | # Save configuration to the model folder 44 | with open(os.path.join(foldername, "config.json"), "w") as f: 45 | json.dump(config, f, indent=4) 46 | 47 | # Model setup 48 | model = TSDE_Forecasting(config, args.device, target_dim=321, sample_feat=args.sample_feat).to(args.device) 49 | train_loader, valid_loader, test_loader, _ = get_dataloader_elec( 50 | pred_length=args.pred_length, 51 | history_length=args.hist_length, 52 | batch_size=config["train"]["batch_size"], 53 | device=args.device, 54 | ) 55 | if args.load is None: 56 | # Start training 57 | train( 58 | model, 59 | config["train"], 60 | train_loader, 61 | valid_loader=valid_loader, 62 | test_loader=test_loader, 63 | foldername=foldername, 64 | nsample=args.nsample, 65 | scaler=1, 66 | mean_scaler=0, 67 | eval_epoch_interval=20000, 68 | ) 69 | else: 70 | model.load_state_dict(torch.load("./save/" + args.load + "/model.pth", map_location=args.device)) 71 | 72 | evaluate(model, test_loader, nsample=args.nsample, foldername=foldername) 73 | if __name__ == "__main__": 74 | main() 75 | 
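The `diffusion` settings shared by the configs above (beta_start=0.0001, beta_end=0.5, num_steps=50, schedule "quad") define the forward noising process that the denoising block learns to invert. The snippet below is a minimal, illustrative sketch assuming a CSDI-style quadratic beta schedule; it is not the repository implementation (the TSDE models in `src/tsde/main_model.py`, not shown in this excerpt, consume these settings).

```python
import numpy as np

# Illustrative sketch only -- assumes a CSDI-style "quad" schedule;
# the constants below are taken from the config files above.
beta_start, beta_end, num_steps = 1e-4, 0.5, 50
betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_steps) ** 2  # quadratic ("quad") schedule
alpha_hat = np.cumprod(1.0 - betas)                                      # cumulative product of (1 - beta_t)

def corrupt(x0, t, rng=np.random.default_rng(0)):
    """Sample x_t ~ q(x_t | x_0) for a 0-indexed diffusion step t."""
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_hat[t]) * x0 + np.sqrt(1.0 - alpha_hat[t]) * noise, noise

x0 = np.random.default_rng(1).standard_normal((35, 48))  # K=35 features, L=48 timestamps (base.yaml)
x_t, eps = corrupt(x0, num_steps - 1)                     # at the final step x_t is close to pure noise
```

During pre-training, the denoising block is optimized to recover `eps` from the corrupted series, the diffusion-step embedding, and the embedding of the observed part of the MTS.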
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # Data 156 | data 157 | 158 | # Results 159 | save 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | -------------------------------------------------------------------------------- /src/experiments/train_test_interpolation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | 8 | 9 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 10 | sys.path.append(root_dir) 11 | 12 | 13 | from data_loader.physio_dataloader import get_dataloader_physio 14 | from tsde.main_model import TSDE_Physio 15 | from utils.utils import train, evaluate, gsutil_cp, set_seed 16 | 17 | torch.backends.cudnn.enabled = False 18 | 19 | parser = argparse.ArgumentParser(description="TSDE-Interpolation") 20 | parser.add_argument("--config", type=str, default="base.yaml") 21 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 22 | parser.add_argument("--seed", type=int, default=1) 23 | 24 | 25 | parser.add_argument("--dataset", type=str, default='PhysioNet') 26 | parser.add_argument("--modelfolder", type=str, default="") 27 | parser.add_argument("--nsample", type=int, default=100) 28 | parser.add_argument("--run", type=int, default=1) 29 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal_p or probabilistic_layering)") 30 | 31 | 32 | ## Args for physio 33 | parser.add_argument("--nfold", type=int, default=0, help="for 5fold test (valid value:[0-4])") 34 | parser.add_argument("--testmissingratio", type=float, default=0.1) 35 | parser.add_argument("--physionet_classification", type=bool, default=False) 36 | 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | path = "src/config/" + args.config 41 | with open(path, "r") as f: 42 | config = yaml.safe_load(f) 43 | 44 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 45 | config["model"]["test_missing_ratio"] = args.testmissingratio 46 | 47 | print(json.dumps(config, indent=4)) 48 | 49 | set_seed(args.seed) 50 | 51 | if args.dataset == "PhysioNet": 52 | foldername = "./save/Interpolation/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_missing_ratio_' + str(args.testmissingratio) +"/" 53 | model = TSDE_Physio(config, args.device).to(args.device) 54 | train_loader, valid_loader, 
test_loader = get_dataloader_physio(seed=args.seed, nfold=args.nfold, batch_size=config["train"]["batch_size"], missing_ratio=config["model"]["test_missing_ratio"], mode='interpolation') 55 | scaler = 1 56 | mean_scaler = 0 57 | mode = 'Interpolation' 58 | else: 59 | print() 60 | 61 | print('model folder:', foldername) 62 | os.makedirs(foldername, exist_ok=True) 63 | 64 | with open(foldername + "config.json", "w") as f: 65 | json.dump(config, f, indent=4) 66 | 67 | if args.modelfolder == "": 68 | loss_path = foldername + "/losses.txt" 69 | with open(loss_path, "a") as file: 70 | file.write("Pretraining"+"\n") 71 | ## Pre-training 72 | train(model, config["train"], train_loader, valid_loader=valid_loader, test_loader=test_loader, foldername=foldername, nsample=args.nsample, 73 | scaler=scaler, mean_scaler=mean_scaler, eval_epoch_interval=500,physionet_classification=args.physionet_classification) 74 | if config["finetuning"]["epochs"]!=0: 75 | print('Finetuning') 76 | ## Fine Tuning 77 | with open(loss_path, "a") as file: 78 | file.write("Finetuning"+"\n") 79 | checkpoint_path = foldername + "model.pth" 80 | model.load_state_dict(torch.load(checkpoint_path)) 81 | config["train"]["epochs"]=config["finetuning"]["epochs"] 82 | train( 83 | model, 84 | config["train"], 85 | train_loader, 86 | valid_loader=valid_loader, 87 | foldername=foldername, 88 | mode = mode, 89 | ) 90 | 91 | else: 92 | model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)) 93 | 94 | evaluate( 95 | model, 96 | test_loader, 97 | nsample=args.nsample, 98 | scaler=scaler, 99 | mean_scaler=mean_scaler, 100 | foldername=foldername, 101 | save_samples = True, 102 | physionet_classification=args.physionet_classification 103 | 104 | ) 105 | 106 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import auc 4 | 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def quantile_loss(target, forecast, q: float, eval_points) -> float: 10 | """ 11 | Calculates the quantile loss for a given forecast and target values based on a specific quantile. 12 | 13 | Parameters: 14 | - target: Torch tensor containing the observed values. 15 | - forecast: Torch tensor containing the predicted values. 16 | - q: Float, representing the quantile for which the loss is calculated. 17 | - eval_points: Torch tensor representing the evaluation points for the calculation. 18 | 19 | Returns: 20 | - Float representing the calculated quantile loss. 21 | """ 22 | return 2 * torch.sum( 23 | torch.abs((forecast - target) * eval_points * ((target <= forecast) * 1.0 - q)) 24 | ) 25 | 26 | 27 | def calc_denominator(target, eval_points): 28 | """ 29 | Calculates the denominator used in the CRPS and CRPS-sum calculation, based on the absolute sum of the target values used for evaluation. 30 | 31 | Parameters: 32 | - target: Torch tensor containing the target values. 33 | - eval_points: Torch tensor representing the evaluation points for the calculation. 34 | 35 | Returns: 36 | - Torch tensor representing the denominator value. 37 | """ 38 | return torch.sum(torch.abs(target * eval_points)) 39 | 40 | 41 | def calc_quantile_CRPS(target, forecast, eval_points, mean_scaler, scaler): 42 | """ 43 | Calculates the CRPS based on quantile loss for multiple quantiles. 
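    The CRPS is approximated by averaging the pinball (quantile) loss over the quantiles 0.05, 0.10, ..., 0.95, normalized by the sum of the absolute target values at the evaluated points.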
44 | 45 | Parameters: 46 | - target: Torch tensor containing the target values. 47 | - forecast: Torch tensor containing the predicted values. 48 | - eval_points: Torch tensor representing the evaluation points for the calculation. 49 | - mean_scaler: Float, the mean value used for scaling the target and forecast back to their original values. 50 | - scaler: Float, the scale value used for scaling the target and forecast back to their original values. 51 | 52 | Returns: 53 | - Float representing the calculated CRPS. 54 | """ 55 | target = target * scaler + mean_scaler 56 | forecast = forecast * scaler + mean_scaler 57 | 58 | quantiles = np.arange(0.05, 1.0, 0.05) 59 | denom = calc_denominator(target, eval_points) 60 | CRPS = 0 61 | for i in range(len(quantiles)): 62 | q_pred = [] 63 | for j in range(len(forecast)): 64 | q_pred.append(torch.quantile(forecast[j : j + 1], quantiles[i], dim=1)) 65 | q_pred = torch.cat(q_pred, 0) 66 | q_loss = quantile_loss(target, q_pred, quantiles[i], eval_points) 67 | CRPS += q_loss / denom 68 | return CRPS.item() / len(quantiles) 69 | 70 | def calc_quantile_CRPS_sum(target, forecast, eval_points, mean_scaler, scaler): 71 | """ 72 | Calculates the CRPS for the sum of the target and predicted values across all features. 73 | 74 | Parameters: 75 | - target: Torch tensor containing the target values. 76 | - forecast: Torch tensor containing the predicted values. 77 | - eval_points: Torch tensor representing the evaluation points for the calculation. 78 | - mean_scaler: Float, the mean value used for scaling the target and predictions back to their original values. 79 | - scaler: Float, the scale value used for scaling the target and predictions back to their original values. 80 | 81 | Returns: 82 | - Float representing the calculated CRPS-sum. 83 | """ 84 | target = target * scaler + mean_scaler 85 | forecast = forecast * scaler + mean_scaler 86 | target_sum = torch.sum(target, dim = 2).unsqueeze(2) 87 | forecast_sum = torch.sum(forecast, dim = 3).unsqueeze(3) 88 | eval_points_sum = torch.mean(eval_points, dim=2).unsqueeze(2) 89 | 90 | crps_sum = calc_quantile_CRPS(target_sum, forecast_sum, eval_points_sum, 0, 1) 91 | return crps_sum 92 | 93 | 94 | def save_roc_curve(fpr, tpr, foldername): 95 | """ 96 | Generates and saves an ROC curve to the specified file path. 97 | 98 | Parameters: 99 | y_true (array-like): True binary labels. 100 | y_scores (array-like): Scores assigned by the classifier. 101 | file_path (str): Path where the ROC curve image will be saved. 102 | 103 | Returns: 104 | str: The file path where the image was saved. 
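    Note: as implemented, the function takes the precomputed ROC points `fpr` and `tpr` plus an output folder `foldername`, and writes the figure to <foldername>roc.png.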
105 | """ 106 | 107 | roc_auc = auc(fpr, tpr) 108 | 109 | # Plotting and saving the ROC curve 110 | plt.figure() 111 | plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})') 112 | plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') 113 | plt.xlim([0.0, 1.0]) 114 | plt.ylim([0.0, 1.05]) 115 | plt.xlabel('False Positive Rate') 116 | plt.ylabel('True Positive Rate') 117 | plt.title('Receiver Operating Characteristic') 118 | plt.legend(loc="lower right") 119 | plt.savefig(foldername+'roc.png') 120 | plt.close() -------------------------------------------------------------------------------- /src/experiments/train_test_imputation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | import numpy as np 8 | from multiprocessing import freeze_support 9 | 10 | 11 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 12 | sys.path.append(root_dir) 13 | 14 | from data_loader.pm25_dataloader import get_dataloader_pm25 15 | from data_loader.physio_dataloader import get_dataloader_physio 16 | from tsde.main_model import TSDE_PM25, TSDE_Physio 17 | from utils.utils import train, evaluate, gsutil_cp, set_seed 18 | 19 | torch.backends.cudnn.enabled = False 20 | 21 | parser = argparse.ArgumentParser(description="TSDE-Imputation") 22 | parser.add_argument("--config", type=str, default="base.yaml") 23 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 24 | parser.add_argument("--seed", type=int, default=1) 25 | 26 | parser.add_argument("--dataset", type=str, default='Pm25') 27 | parser.add_argument("--modelfolder", type=str, default="") 28 | parser.add_argument("--nsample", type=int, default=100) 29 | parser.add_argument("--run", type=int, default=0) 30 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal_p or probabilistic_layering)") 31 | 32 | 33 | ## Args for pm25 34 | parser.add_argument("--validationindex", type=int, default=0, help="index of month used for validation (value:[0-7])") 35 | ## Args for physio 36 | parser.add_argument("--nfold", type=int, default=0, help="for 5fold test (valid value:[0-4])") 37 | parser.add_argument("--testmissingratio", type=float, default=0.1) 38 | parser.add_argument("--physionet_classification", type=bool, default=False) 39 | 40 | args = parser.parse_args() 41 | print(args) 42 | 43 | path = "src/config/" + args.config 44 | with open(path, "r") as f: 45 | config = yaml.safe_load(f) 46 | 47 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 48 | config["model"]["test_missing_ratio"] = args.testmissingratio 49 | 50 | print(json.dumps(config, indent=4)) 51 | 52 | set_seed(args.seed) 53 | 54 | 55 | if args.dataset == "Pm25": 56 | foldername = "./save/Imputation/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_validationindex_' + str(args.validationindex) +"/" 57 | model = TSDE_PM25(config, args.device).to(args.device) 58 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_pm25(config["train"]["batch_size"], device=args.device, validindex=args.validationindex) 59 | mode = 'Imputation with pattern' 60 | 61 | elif args.dataset == "PhysioNet": 62 | foldername = "./save/Imputation/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_missing_ratio_' + str(args.testmissingratio) +"/" 63 | model = 
TSDE_Physio(config, args.device).to(args.device) 64 | train_loader, valid_loader, test_loader = get_dataloader_physio(seed=args.seed, nfold=args.nfold, batch_size=config["train"]["batch_size"], missing_ratio=config["model"]["test_missing_ratio"]) 65 | scaler = 1 66 | mean_scaler = 0 67 | mode = 'Imputation' 68 | else: 69 | print() 70 | 71 | print('model folder:', foldername) 72 | os.makedirs(foldername, exist_ok=True) 73 | 74 | with open(foldername + "config.json", "w") as f: 75 | json.dump(config, f, indent=4) 76 | 77 | 78 | 79 | def main(): 80 | if args.modelfolder == "": 81 | loss_path = foldername + "/losses.txt" 82 | with open(loss_path, "a") as file: 83 | file.write("Pretraining"+"\n") 84 | 85 | ## Pre-training 86 | train(model, config["train"], train_loader, valid_loader=valid_loader, test_loader=test_loader, foldername=foldername, nsample=args.nsample, 87 | scaler=scaler, mean_scaler=mean_scaler, eval_epoch_interval=500,physionet_classification=args.physionet_classification) 88 | 89 | if config["finetuning"]["epochs"]!=0: 90 | print('Finetuning') 91 | ## Fine Tuning 92 | with open(loss_path, "a") as file: 93 | file.write("Finetuning"+"\n") 94 | checkpoint_path = foldername + "model.pth" 95 | model.load_state_dict(torch.load(checkpoint_path)) 96 | config["train"]["epochs"]=config["finetuning"]["epochs"] 97 | train( 98 | model, 99 | config["train"], 100 | train_loader, 101 | valid_loader=valid_loader, 102 | foldername=foldername, 103 | mode = mode, 104 | ) 105 | 106 | else: 107 | model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)) 108 | 109 | evaluate( 110 | model, 111 | test_loader, 112 | nsample=args.nsample, 113 | scaler=scaler, 114 | mean_scaler=mean_scaler, 115 | foldername=foldername, 116 | save_samples = True, 117 | physionet_classification=args.physionet_classification 118 | ) 119 | 120 | 121 | 122 | 123 | if __name__ == '__main__': 124 | #freeze_support() # Recommended for Windows if you plan to freeze your script 125 | main() 126 | -------------------------------------------------------------------------------- /src/experiments/train_test_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | import torch.nn as nn 8 | 9 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 10 | sys.path.append(root_dir) 11 | 12 | from data_loader.physio_dataloader import get_dataloader_physio, get_physio_dataloader_for_classification 13 | from tsde.main_model import TSDE_Physio 14 | from utils.utils import train, evaluate, gsutil_cp, set_seed, finetune, evaluate_finetuning 15 | 16 | torch.backends.cudnn.enabled = False 17 | 18 | parser = argparse.ArgumentParser(description="TSDE-Imputation-Classification") 19 | parser.add_argument("--config", type=str, default="base_classification.yaml") 20 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 21 | parser.add_argument("--seed", type=int, default=1) 22 | 23 | parser.add_argument("--modelfolder", type=str, default="") 24 | parser.add_argument("--nsample", type=int, default=100) 25 | parser.add_argument("--run", type=int, default=1) 26 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal p or probabilistic layering)") 27 | parser.add_argument("--disable_finetune", action="store_true") 28 | 29 | ## Args for physio 30 | 
parser.add_argument("--nfold", type=int, default=0, help="for 5fold test (valid value:[0-4])") 31 | parser.add_argument("--testmissingratio", type=float, default=0.1) 32 | parser.add_argument("--physionet_classification", type=bool, default=True) 33 | 34 | args = parser.parse_args() 35 | print(args) 36 | 37 | path = "src/config/" + args.config 38 | with open(path, "r") as f: 39 | config = yaml.safe_load(f) 40 | 41 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 42 | config["model"]["test_missing_ratio"] = args.testmissingratio 43 | 44 | print(json.dumps(config, indent=4)) 45 | 46 | set_seed(args.seed) 47 | 48 | 49 | foldername = "./save/Imputation-Classification/" + 'PhysioNet' + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_missing_ratio_' + str(args.testmissingratio) +"/" 50 | model = TSDE_Physio(config, args.device).to(args.device) 51 | train_loader, valid_loader, test_loader = get_dataloader_physio(seed=args.seed, nfold=args.nfold, batch_size=config["train"]["batch_size"], missing_ratio=config["model"]["test_missing_ratio"]) 52 | scaler = 1 53 | mean_scaler = 0 54 | mode = 'Imputation' 55 | 56 | 57 | print('model folder:', foldername) 58 | os.makedirs(foldername, exist_ok=True) 59 | 60 | with open(foldername + "config.json", "w") as f: 61 | json.dump(config, f, indent=4) 62 | 63 | if args.modelfolder == "": 64 | os.makedirs(foldername+'Pretrained/', exist_ok=True) 65 | loss_path = foldername + "losses.txt" 66 | with open(loss_path, "a") as file: 67 | file.write("Pretraining"+"\n") 68 | ## Pre-training 69 | train(model, config["train"], train_loader, valid_loader=valid_loader, test_loader=test_loader, foldername=foldername+'Pretrained/', nsample=args.nsample, 70 | scaler=scaler, mean_scaler=mean_scaler, eval_epoch_interval=100000,physionet_classification=args.physionet_classification) 71 | 72 | ## Save imputed time series in Train, Validation and Test sets 73 | evaluate( 74 | model, 75 | train_loader, 76 | nsample=args.nsample, 77 | scaler=scaler, 78 | mean_scaler=mean_scaler, 79 | foldername=foldername+'Pretrained/', 80 | save_samples = True, 81 | physionet_classification=True, 82 | set_type = 'Train' 83 | ) 84 | 85 | evaluate( 86 | model, 87 | valid_loader, 88 | nsample=args.nsample, 89 | scaler=scaler, 90 | mean_scaler=mean_scaler, 91 | foldername=foldername+'Pretrained/', 92 | save_samples = True, 93 | physionet_classification=True, 94 | set_type = 'Val' 95 | ) 96 | evaluate( 97 | model, 98 | test_loader, 99 | nsample=args.nsample, 100 | scaler=scaler, 101 | mean_scaler=mean_scaler, 102 | foldername=foldername+'Pretrained/', 103 | save_samples = True, 104 | physionet_classification=True, 105 | set_type = 'Test' 106 | ) 107 | 108 | 109 | ## Prepare dataloaders for classification head finetuning 110 | train_loader_classification, valid_loader_classification, test_loader_classification = get_physio_dataloader_for_classification(filename=foldername+'Pretrained/', batch_size=config["train"]["batch_size"]) 111 | model.load_state_dict(torch.load(foldername + "Pretrained/model.pth", map_location=args.device)) 112 | 113 | 114 | else: 115 | # Load pretrained and dataloaders of imputed MTS 116 | train_loader_classification, valid_loader_classification, test_loader_classification = get_physio_dataloader_for_classification(filename=args.modelfolder+"Pretrained/", batch_size=config["train"]["batch_size"]) 117 | model.load_state_dict(torch.load(args.modelfolder + "Pretrained/model.pth", map_location=args.device)) 118 | 119 | print(args.disable_finetune) 120 
| if not args.disable_finetune: 121 | for name, param in model.named_parameters(): 122 | # Freeze all parameters 123 | param.requires_grad = False 124 | 125 | 126 | for name, param in model.mlp.named_parameters(): 127 | param.requires_grad = True 128 | for name, param in model.named_parameters(): 129 | print(f"{name}: {param.requires_grad}") 130 | 131 | ## Finetune the classifier head 132 | finetune(model, config["finetuning"], train_loader_classification, criterion = nn.CrossEntropyLoss(), foldername=foldername) 133 | else: 134 | model.load_state_dict(torch.load(args.modelfolder + "model.pth", map_location=args.device)) 135 | 136 | 137 | ## Evaluate the classification 138 | evaluate_finetuning(model, train_loader_classification, test_loader_classification, foldername=foldername) 139 | 140 | 141 | -------------------------------------------------------------------------------- /src/base/denoisingNetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from base.diffEmbedding import DiffusionEmbedding 7 | 8 | 9 | def Conv1d_with_init(in_channels, out_channels, kernel_size): 10 | """ 11 | Initializes a 1D convolutional layer with Kaiming normal initialization. 12 | 13 | Parameters: 14 | - in_channels (int): Number of channels in the input signal. 15 | - out_channels (int): Number of channels produced by the convolution. 16 | - kernel_size (int): Size of the convolving kernel. 17 | 18 | Returns: 19 | - nn.Conv1d: A 1D convolutional layer with weights initialized. 20 | """ 21 | layer = nn.Conv1d(in_channels, out_channels, kernel_size) 22 | nn.init.kaiming_normal_(layer.weight) 23 | return layer 24 | 25 | 26 | class diff_Block(nn.Module): 27 | """ 28 | A neural network block that incorporates diffusion embedding, designed for the reverse pass in DDPM. It corresponds to the denoising block in the TSDE architecture. 29 | 30 | Parameters: 31 | - config (dict): Configuration dictionary containing model settings for the denoising block. 32 | """ 33 | def __init__(self, config): 34 | super().__init__() 35 | 36 | self.channels = config["channels"] 37 | 38 | self.diffusion_embedding = DiffusionEmbedding( 39 | num_steps=config["num_steps"], 40 | embedding_dim=config["diffusion_embedding_dim"], 41 | ) 42 | 43 | self.input_projection = Conv1d_with_init(1, self.channels, 1) 44 | self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1) 45 | self.output_projection2 = Conv1d_with_init(self.channels, 1, 1) 46 | nn.init.zeros_(self.output_projection2.weight) 47 | 48 | self.residual_layers = nn.ModuleList( 49 | [ 50 | ResidualBlock( 51 | mts_emb_dim=config["mts_emb_dim"], 52 | channels=self.channels, 53 | diffusion_embedding_dim=config["diffusion_embedding_dim"], 54 | ) 55 | for _ in range(config["layers"]) 56 | ] 57 | ) 58 | 59 | def forward(self, x, mts_emb, diffusion_step): 60 | """ 61 | Forward pass of the denoising block. 62 | 63 | Parameters: 64 | - x (Tensor): The corrupted input MTS to be denoised. 65 | - mts_emb (Tensor): The embedding of the observed part of the MTS. 66 | - diffusion_step (Tensor): The current diffusion step index. 67 | 68 | Returns: 69 | - Tensor: The output tensor of the predicted noise added in x at diffusion_step. 
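        The returned tensor has shape (B, K, L): one predicted-noise value per feature/timestamp position.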
70 | """ 71 | B, inputdim, K, L = x.shape 72 | 73 | x = x.reshape(B, inputdim, K * L) 74 | x = self.input_projection(x) ## First Convolution before fedding the data to the 75 | x = F.relu(x) ## residual block 76 | x = x.reshape(B, self.channels, K, L) 77 | 78 | diffusion_emb = self.diffusion_embedding(diffusion_step) 79 | 80 | skip = [] 81 | for layer in self.residual_layers: 82 | x, skip_connection = layer(x, mts_emb, diffusion_emb) 83 | skip.append(skip_connection) 84 | 85 | x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) 86 | x = x.reshape(B, self.channels, K * L) 87 | x = self.output_projection1(x) # (B,channel,K*L) 88 | x = F.relu(x) 89 | x = self.output_projection2(x) # (B,1,K*L) 90 | x = x.reshape(B, K, L) 91 | return x 92 | 93 | 94 | class ResidualBlock(nn.Module): 95 | """ 96 | A residual block that processes input data alongside diffusion embeddings and observed MTS part embedding, 97 | utilizing a gated mechanism. 98 | 99 | Parameters: 100 | - mts_emb_dim (int): Dimensionality of the embedding of the MTS 101 | - channels (int): Number of channels for the convolutional layers within the block. 102 | - diffusion_embedding_dim (int): Dimensionality of the diffusion embeddings. 103 | 104 | """ 105 | def __init__(self, mts_emb_dim, channels, diffusion_embedding_dim): 106 | super().__init__() 107 | self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels) 108 | self.cond_projection = Conv1d_with_init(mts_emb_dim, 2 * channels, 1) 109 | self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1) 110 | self.output_projection = Conv1d_with_init(channels, 2 * channels, 1) 111 | 112 | 113 | def forward(self, x, mts_emb, diffusion_emb): 114 | """ 115 | Forward pass of the ResidualBlock. 116 | 117 | Parameters: 118 | - x (Tensor): The projected corrupted input MTS. 119 | - mts_emb (Tensor): The embedding of the observed part of the MTS. 120 | - diffusion_emb (Tensor): The projected diffusion embedding tensor. 121 | 122 | Returns: 123 | - Tuple[Tensor, Tensor]: A tuple containing the updated data tensor and a skip connection tensor. 
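        Both outputs have the same shape as x; the residual branch is added to the input and rescaled by 1/sqrt(2) to keep activations stable across stacked blocks.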
124 | """ 125 | 126 | B, channel, K, L = x.shape 127 | base_shape = x.shape 128 | x = x.reshape(B, channel, K * L) 129 | 130 | diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(-1) # (B,channel,1) 131 | y = x + diffusion_emb 132 | 133 | y = self.mid_projection(y) # (B,2*channel,K*L) 134 | 135 | _, mts_emb_dim, _, _ = mts_emb.shape 136 | mts_emb = mts_emb.reshape(B, mts_emb_dim, K * L) #B, C, K*L 137 | mts_emb = self.cond_projection(mts_emb) # (B,2*channel,K*L) 138 | y = y + mts_emb 139 | 140 | gate, filter = torch.chunk(y, 2, dim=1) 141 | y = torch.sigmoid(gate) * torch.tanh(filter) # (B,channel,K*L) 142 | y = self.output_projection(y) 143 | 144 | residual, skip = torch.chunk(y, 2, dim=1) 145 | x = x.reshape(base_shape) 146 | residual = residual.reshape(base_shape) 147 | skip = skip.reshape(base_shape) 148 | return (x + residual) / math.sqrt(2.0), skip -------------------------------------------------------------------------------- /src/experiments/train_test_forecasting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | 8 | 9 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 10 | sys.path.append(root_dir) 11 | 12 | from data_loader.forecasting_dataloader import get_dataloader_forecasting 13 | from tsde.main_model import TSDE_Forecasting 14 | from utils.utils import train, evaluate, gsutil_cp, set_seed 15 | 16 | torch.backends.cudnn.enabled = False 17 | 18 | parser = argparse.ArgumentParser(description="TSDE-Forecasting") 19 | parser.add_argument("--config", type=str, default="base_forecasting.yaml") 20 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 21 | parser.add_argument("--seed", type=int, default=1) 22 | parser.add_argument('--linear', action='store_true', help='Linear mode flag') 23 | parser.add_argument('--sample_feat', action='store_true', help='Sample feature flag') 24 | 25 | 26 | parser.add_argument("--dataset", type=str, default='Electricity') 27 | parser.add_argument("--modelfolder", type=str, default="") 28 | parser.add_argument("--nsample", type=int, default=100) 29 | parser.add_argument("--run", type=int, default=1) 30 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal_p or probabilistic_layering)") 31 | 32 | args = parser.parse_args() 33 | print(args) 34 | 35 | path = "src/config/" + args.config 36 | with open(path, "r") as f: 37 | config = yaml.safe_load(f) 38 | 39 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 40 | 41 | 42 | print(json.dumps(config, indent=4)) 43 | 44 | set_seed(args.seed) 45 | 46 | if args.dataset == "Electricity": 47 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + '_sample_feat_' + str(args.sample_feat)+"/" 48 | model = TSDE_Forecasting(config, args.device, target_dim=370, sample_feat=args.sample_feat).to(args.device) 49 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 50 | dataset_name='electricity', 51 | train_length=5833, 52 | skip_length=370*6, 53 | batch_size=config["train"]["batch_size"], 54 | device= args.device, 55 | ) 56 | 57 | elif args.dataset == "Solar": 58 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + 
'_sample_feat_' + str(args.sample_feat)+"/" 59 | model = TSDE_Forecasting(config, args.device, target_dim=137, sample_feat=args.sample_feat).to(args.device) 60 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 61 | dataset_name='solar', 62 | train_length=7009, 63 | skip_length=137*6, 64 | batch_size=config["train"]["batch_size"], 65 | device= args.device, 66 | ) 67 | 68 | elif args.dataset == "Traffic": 69 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + '_sample_feat_' + str(args.sample_feat)+"/" 70 | model = TSDE_Forecasting(config, args.device, target_dim=963, sample_feat=args.sample_feat).to(args.device) 71 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 72 | dataset_name='traffic', 73 | train_length=4001, 74 | skip_length=963*6, 75 | batch_size=config["train"]["batch_size"], 76 | device= args.device, 77 | ) 78 | 79 | elif args.dataset == "Taxi": 80 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + '_sample_feat_' + str(args.sample_feat)+"/" 81 | model = TSDE_Forecasting(config, args.device, target_dim=1214, sample_feat=args.sample_feat).to(args.device) 82 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 83 | dataset_name='taxi', 84 | train_length=1488, 85 | skip_length=1214*55, 86 | test_length=24*56, 87 | history_length=48, 88 | batch_size=config["train"]["batch_size"], 89 | device= args.device, 90 | ) 91 | 92 | elif args.dataset == "Wiki": 93 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + '_sample_feat_' + str(args.sample_feat)+"/" 94 | model = TSDE_Forecasting(config, args.device, target_dim=2000, sample_feat=args.sample_feat).to(args.device) 95 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 96 | dataset_name='wiki', 97 | train_length=792, 98 | skip_length=9535*4, 99 | test_length=30*5, 100 | valid_length=30*5, 101 | history_length=90, 102 | pred_length=30, 103 | batch_size=config["train"]["batch_size"], 104 | device= args.device, 105 | ) 106 | 107 | else: 108 | print() 109 | 110 | 111 | print('model folder:', foldername) 112 | os.makedirs(foldername, exist_ok=True) 113 | 114 | 115 | with open(foldername + "config.json", "w") as f: 116 | json.dump(config, f, indent=4) 117 | 118 | 119 | if args.modelfolder == "": 120 | loss_path = foldername + "/losses.txt" 121 | with open(loss_path, "a") as file: 122 | file.write("Pretraining"+"\n") 123 | train( 124 | model, 125 | config["train"], 126 | train_loader, 127 | valid_loader=valid_loader, 128 | test_loader=test_loader, 129 | foldername=foldername, 130 | nsample=args.nsample, 131 | scaler=scaler, 132 | mean_scaler=mean_scaler, 133 | eval_epoch_interval=200, 134 | ) 135 | if config["finetuning"]["epochs"]!=0: 136 | print("Finetuning") 137 | ## Fine Tuning 138 | with open(loss_path, "a") as file: 139 | file.write("Finetuning"+"\n") 140 | checkpoint_path = foldername + "model.pth" 141 | model.load_state_dict(torch.load(checkpoint_path)) 142 | config["train"]["epochs"]=config["finetuning"]["epochs"] 143 | train( 144 | model, 145 | config["train"], 146 | train_loader, 147 | valid_loader=valid_loader, 148 | foldername=foldername, 149 | mode = 'Forecasting', 150 | ) 151 | 152 | 
else: 153 | model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)) 154 | 155 | evaluate(model, test_loader, nsample=args.nsample, foldername=foldername, scaler=scaler, mean_scaler=mean_scaler,save_samples = True) 156 | 157 | -------------------------------------------------------------------------------- /src/base/mtsEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def get_torch_trans(heads=8, layers=1, channels=64): 6 | """ 7 | Creates a Transformer encoder module to process MTS timestamps/features as sequences. 8 | 9 | Parameters: 10 | - heads (int): Number of attention heads. 11 | - layers (int): Number of encoder layers. 12 | - channels (int): Dimensionality of the model (d_model in Transformer terminology). 13 | 14 | Returns: 15 | - nn.TransformerEncoder: A Transformer encoder object. 16 | """ 17 | encoder_layer = nn.TransformerEncoderLayer( 18 | d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu" 19 | ) 20 | return nn.TransformerEncoder(encoder_layer, num_layers=layers, enable_nested_tensor=False) 21 | 22 | 23 | def Conv1d_with_init(in_channels, out_channels, kernel_size): 24 | """ 25 | Initializes a 1D convolutional layer with Kaiming normal initialization. 26 | 27 | Parameters: 28 | - in_channels (int): Number of channels in the input signal. 29 | - out_channels (int): Number of channels produced by the convolution. 30 | - kernel_size (int): Size of the convolving kernel. 31 | 32 | Returns: 33 | - nn.Conv1d: A 1D convolutional layer with weights initialized. 34 | """ 35 | layer = nn.Conv1d(in_channels, out_channels, kernel_size) 36 | nn.init.kaiming_normal_(layer.weight) 37 | return layer 38 | 39 | class embedding_MTS(nn.Module): 40 | """ 41 | An embedding module for multivariate time series data, incorporating both temporal and spatial 42 | encoding via Transformer encoders and convolutional layers, it corresponds to the Embedding block in TSDE architecture. 43 | 44 | Parameters: 45 | - config (dict): A configuration dictionary containing model parameters such as number of channels, 46 | embedding dimensions, and number of heads for the Transformer encoders in the embedding block. 47 | """ 48 | 49 | 50 | def __init__(self, config): 51 | super().__init__() 52 | self.channels = config["channels"] 53 | self.timeemb = config["timeemb"] 54 | self.featureemb = config["featureemb"] 55 | self.emb_size = self.channels + self.timeemb + self.featureemb 56 | self.time_layer = get_torch_trans(heads=config["nheads"], layers=1, channels=self.emb_size) ## Temporal encoder 57 | self.feature_layer = get_torch_trans(heads=config["nheads"], layers=1, channels=self.emb_size) ## Spatial encoder 58 | 59 | self.input_projection = Conv1d_with_init(1, self.channels, 1) 60 | 61 | self.xt_projection = Conv1d_with_init(self.emb_size, self.channels, 1) 62 | self.xf_projection = Conv1d_with_init(self.emb_size, self.channels, 1) 63 | 64 | 65 | def forward_time(self, x, base_shape): 66 | """ 67 | Processes the input data through the temporal Transformer encoder (timestamps are considered as tokens). 68 | 69 | Parameters: 70 | - x (Tensor): The observed part of the MTS combined with timestamps embedding and features embedding as tensor. 71 | - base_shape (tuple): The base shape of the input tensor before reshaping. 72 | 73 | Returns: 74 | - Tensor: The temporally encoded tensor. 
75 | """ 76 | B, C, K, L = base_shape 77 | if L == 1: 78 | return x 79 | x = x.permute(0, 2, 1, 3).reshape(B * K, C, L) 80 | x = self.time_layer(x.permute(2, 0, 1)).permute(1, 2, 0) 81 | x = x.reshape(B, K, C, L).permute(0, 2, 1, 3) 82 | return x 83 | 84 | def forward_feature(self, x, base_shape): 85 | """ 86 | Processes the input data through the spatial Transformer encoder (features are considered as tokens). 87 | 88 | Parameters: 89 | - x (Tensor): The observed part of the MTS combined with timestamps embedding and features embedding as tensor. 90 | - base_shape (tuple): The base shape of the input tensor before reshaping. 91 | 92 | Returns: 93 | - Tensor: The spatially encoded tensor. 94 | """ 95 | B, C, K, L = base_shape 96 | if K == 1: 97 | return x 98 | x = x.permute(0, 3, 1, 2).reshape(B * L, C, K) 99 | x = self.feature_layer(x.permute(2, 0, 1)).permute(1, 2, 0) 100 | x = x.reshape(B, L, C, K).permute(0, 2, 3, 1) 101 | return x 102 | 103 | def forward(self, x, time_embed, feature_embed): 104 | """ 105 | The forward pass of the embedding module, processing multivariate time series data with 106 | temporal and spatial embeddings. 107 | 108 | Parameters: 109 | - x (Tensor): The observed part of the MTS. 110 | - time_embed (Tensor): The time embeddings tensor. 111 | - feature_embed (Tensor): The feature embeddings tensor. 112 | 113 | Returns: 114 | - Tuple[Tensor, Tensor, Tensor]: A tuple containing the combined temporal and spatial embeddings generated by the two transformer encoders, 115 | the processed temporal embedding, and the processed spatial embedding. 116 | """ 117 | B, _, K, L = x.shape 118 | time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1) 119 | feature_embed = feature_embed.unsqueeze(1).expand(-1, L, -1, -1) 120 | 121 | x = x.reshape(B, 1, K * L) 122 | x = self.input_projection(x) ## First Convolution before fedding the data to the encoders 123 | x = F.relu(x) 124 | x = x.reshape(B, self.channels, K, L) 125 | x = torch.cat([x, time_embed.permute(0, 3, 2, 1), feature_embed.permute(0, 3, 2, 1)], dim=1) 126 | 127 | base_shape = B, x.shape[1], K, L 128 | 129 | xt = self.forward_time(x, base_shape) 130 | xf = self.forward_feature(x, base_shape) 131 | 132 | xt = self.forward_time(xf, base_shape) 133 | xf = self.forward_feature(xt, base_shape) 134 | 135 | 136 | xt_reshaped = xt.reshape(B, base_shape[1], K*L) 137 | xt_proj = F.silu(self.xt_projection(xt_reshaped)) 138 | xt = xt_proj.reshape(B, self.channels, K, L) 139 | 140 | 141 | xf_reshaped = xf.reshape(B, base_shape[1], K*L) 142 | xf_proj = F.silu(self.xf_projection(xf_reshaped)) 143 | xf = xf_proj.reshape(B, self.channels, K, L) 144 | 145 | 146 | x = torch.cat([xt, xf], dim=1) # B, 2*C, K, L 147 | 148 | return x, xt_proj, xf_proj -------------------------------------------------------------------------------- /src/data_loader/elec_tslib_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torch.utils.data import Dataset, DataLoader 4 | 5 | import numpy as np 6 | import torch 7 | import warnings 8 | warnings.filterwarnings('ignore') 9 | from sklearn.preprocessing import StandardScaler 10 | 11 | 12 | '''class StandardScaler(object): 13 | 14 | def __init__(self): 15 | self.mean = 0. 16 | self.std = 1. 
17 | 18 | def fit(self, data): 19 | self.mean = data.mean(0) 20 | self.std = data.std(0) 21 | 22 | def transform(self, data): 23 | mean = self.mean 24 | std = self.std 25 | return (data - mean) / std 26 | 27 | def inverse_transform(self, data): 28 | mean = self.mean 29 | std = self.std 30 | return (data * std) + mean''' 31 | 32 | 33 | class Dataset_Custom(Dataset): 34 | 35 | def __init__(self, root_path='dataset/electricity', flag='train', size=None, features='M', data_path='electricity.csv', target='OT', scale=True, 36 | inverse=False, timeenc=1, freq='t', cols=None, percentage=1): 37 | # size [seq_len, label_len, pred_len] 38 | # info 39 | super().__init__() 40 | self.seq_len = size[0] 41 | self.pred_len = size[1] 42 | # init 43 | assert flag in ['train', 'test', 'val'] 44 | type_map = {'train':0, 'val':1, 'test':2} 45 | self.set_type = type_map[flag] 46 | 47 | self.features = features 48 | self.target = target 49 | self.scale = scale 50 | self.inverse = inverse 51 | self.timeenc = timeenc 52 | self.freq = freq 53 | self.percentage = percentage 54 | self.cols=cols 55 | self.root_path = root_path 56 | self.data_path = data_path 57 | self.__read_data__() 58 | 59 | def __read_data__(self): 60 | self.scaler = StandardScaler() 61 | df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path)) 62 | length = len(df_raw)*self.percentage 63 | num_train = int(length*0.7) 64 | num_test = int(length*0.2) 65 | num_vali = int(length*0.1) 66 | # num_vali = len(df_raw) - num_train - num_test 67 | border1s = [0, num_train-self.seq_len, num_train+num_vali-self.seq_len] 68 | border2s = [num_train, num_train+num_vali, num_train+num_vali+num_test] 69 | border1 = border1s[self.set_type] 70 | border2 = border2s[self.set_type] 71 | 72 | if self.features == 'M': 73 | cols_data = df_raw.columns[1:] 74 | df_data = df_raw[cols_data] 75 | elif self.features == 'MS': 76 | cols_data = df_raw.columns[1:] 77 | df_data = df_raw[cols_data] 78 | elif self.features == 'S': 79 | df_data = df_raw[[self.target]] 80 | 81 | if self.scale: 82 | train_data = df_data[border1s[0]:border2s[0]] 83 | self.scaler.fit(train_data.values) 84 | 85 | '''for i in range(len(self.scaler.std)): 86 | if self.scaler.std[i] == 0: 87 | print(i)''' 88 | # print(len(self.scaler.std)) 89 | data = self.scaler.transform(df_data.values) 90 | else: 91 | data = df_data.values 92 | 93 | # df_stamp = df_raw[['date']][border1:border2] 94 | # df_stamp['date'] = pd.to_datetime(df_stamp.date) 95 | df_stamp = pd.date_range(start='4/1/2018',periods=border2-border1, freq='H') 96 | 97 | 98 | self.data_x = data[border1:border2] 99 | if self.inverse: 100 | self.data_y = df_data.values[border1:border2] 101 | else: 102 | self.data_y = data[border1:border2] 103 | 104 | 105 | def __getitem__(self, index): 106 | s_begin = index 107 | s_end = s_begin + self.seq_len 108 | r_begin = s_end 109 | r_end = r_begin + self.pred_len 110 | 111 | seq_x = self.data_x[s_begin:r_end] 112 | 113 | observed_mask = np.ones_like(seq_x) 114 | target_mask=observed_mask.copy() 115 | target_mask[-self.pred_len:] = 0 116 | s = { 117 | 'observed_data': seq_x, 118 | 'observed_mask': observed_mask, 119 | 'gt_mask': target_mask, 120 | 'timepoints': np.arange(self.seq_len+self.pred_len) * 1.0, 121 | 'feature_id': np.arange(seq_x.shape[1]) * 1.0, 122 | } 123 | #return seq_x, seq_y, seq_x_mark, seq_y_mark 124 | return s 125 | 126 | 127 | def __len__(self): 128 | return len(self.data_x) - self.seq_len - self.pred_len + 1 129 | 130 | def inverse_transform(self, data): 131 | return 
self.scaler.inverse_transform(data) 132 | 133 | 134 | 135 | def get_dataloader_elec(pred_length=96, history_length=192, batch_size=8, device='cuda:0'): 136 | 137 | 138 | train_dataset = Dataset_Custom(flag='train', size=[history_length, pred_length]) 139 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 140 | 141 | valid_dataset = Dataset_Custom(flag='val', size=[history_length, pred_length]) 142 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 143 | 144 | test_dataset = Dataset_Custom(flag='test', size=[history_length, pred_length]) 145 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 146 | print('Lengths', len(train_dataset), len(valid_dataset), len(test_dataset)) 147 | 148 | return train_loader, valid_loader, test_loader, test_dataset 149 | 150 | 151 | def get_dataloader_traffic(pred_length=96, history_length=192, batch_size=8, device='cuda:0'): 152 | 153 | train_dataset = Dataset_Custom(root_path='dataset/traffic', data_path='traffic.csv', flag='train', size=[history_length, pred_length]) 154 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 155 | 156 | valid_dataset = Dataset_Custom(root_path='dataset/traffic', data_path='traffic.csv', flag='val', size=[history_length, pred_length]) 157 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 158 | 159 | test_dataset = Dataset_Custom(root_path='dataset/traffic', data_path='traffic.csv', flag='test', size=[history_length, pred_length]) 160 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 161 | print('Lengths', len(train_dataset), len(valid_dataset), len(test_dataset)) 162 | 163 | return train_loader, valid_loader, test_loader, test_dataset 164 | -------------------------------------------------------------------------------- /src/data_loader/pm25_dataloader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from torch.utils.data import DataLoader, Dataset 3 | import pandas as pd 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class PM25_Dataset(Dataset): 9 | """ 10 | A Dataset class for loading and processing PM2.5 air quality data for use with PyTorch models. 11 | 12 | Parameters: 13 | - eval_length (int): The length of the time series evaluation window (total number of timestamps (L)). 14 | - target_dim (int): The total number of the different features in the multivariate time series (K). 15 | - mode (str): Specifies the mode of the dataset. Can be 'train', 'valid', or 'test'. 16 | - validindex (int): Index used to select the validation month in 'train' and 'valid' modes. 
17 | """ 18 | def __init__(self, eval_length=36, target_dim=36, mode="train", validindex=0): 19 | self.eval_length = eval_length 20 | self.target_dim = target_dim 21 | 22 | path = "./data/pm25/pm25_meanstd.pk" 23 | with open(path, "rb") as f: 24 | self.train_mean, self.train_std = pickle.load(f) 25 | if mode == "train": 26 | month_list = [1, 2, 4, 5, 7, 8, 10, 11] 27 | # 1st,4th,7th,10th months are excluded from histmask (since the months are used for creating missing patterns in test dataset) 28 | flag_for_histmask = [0, 1, 0, 1, 0, 1, 0, 1] 29 | month_list.pop(validindex) 30 | flag_for_histmask.pop(validindex) 31 | elif mode == "valid": 32 | month_list = [1, 2, 4, 5, 7, 8, 10, 11] 33 | month_list = month_list[validindex : validindex + 1] 34 | elif mode == "test": 35 | month_list = [3, 6, 9, 12] 36 | self.month_list = month_list 37 | 38 | # create data for batch 39 | self.observed_data = [] # values (separated into each month) 40 | self.observed_mask = [] # masks (separated into each month) 41 | self.gt_mask = [] # ground-truth masks (separated into each month) 42 | self.index_month = [] # indicate month 43 | self.position_in_month = [] # indicate the start position in month (length is the same as index_month) 44 | self.valid_for_histmask = [] # whether the sample is used for histmask 45 | self.use_index = [] # to separate train/valid/test 46 | self.cut_length = [] # excluded from evaluation targets 47 | 48 | df = pd.read_csv( 49 | "./data/pm25/Code/STMVL/SampleData/pm25_ground.txt", 50 | index_col="datetime", 51 | parse_dates=True, 52 | ) 53 | df_gt = pd.read_csv( 54 | "./data/pm25/Code/STMVL/SampleData/pm25_missing.txt", 55 | index_col="datetime", 56 | parse_dates=True, 57 | ) 58 | for i in range(len(month_list)): 59 | current_df = df[df.index.month == month_list[i]] 60 | current_df_gt = df_gt[df_gt.index.month == month_list[i]] 61 | current_length = len(current_df) - eval_length + 1 62 | last_index = len(self.index_month) 63 | self.index_month += np.array([i] * current_length).tolist() 64 | self.position_in_month += np.arange(current_length).tolist() 65 | if mode == "train": 66 | self.valid_for_histmask += np.array( 67 | [flag_for_histmask[i]] * current_length 68 | ).tolist() 69 | 70 | # mask values for observed indices are 1 71 | c_mask = 1 - current_df.isnull().values 72 | c_gt_mask = 1 - current_df_gt.isnull().values 73 | c_data = ( 74 | (current_df.fillna(0).values - self.train_mean) / self.train_std 75 | ) * c_mask 76 | 77 | self.observed_mask.append(c_mask) 78 | self.gt_mask.append(c_gt_mask) 79 | self.observed_data.append(c_data) 80 | 81 | if mode == "test": 82 | n_sample = len(current_df) // eval_length 83 | # interval size is eval_length (missing values are imputed only once) 84 | c_index = np.arange( 85 | last_index, last_index + eval_length * n_sample, eval_length 86 | ) 87 | self.use_index += c_index.tolist() 88 | self.cut_length += [0] * len(c_index) 89 | if len(current_df) % eval_length != 0: # avoid double-count for the last time-series 90 | self.use_index += [len(self.index_month) - 1] 91 | self.cut_length += [eval_length - len(current_df) % eval_length] 92 | 93 | if mode != "test": 94 | self.use_index = np.arange(len(self.index_month)) 95 | self.cut_length = [0] * len(self.use_index) 96 | 97 | # masks for 1st,4th,7th,10th months are used for creating missing patterns in test data, 98 | # so these months are excluded from histmask to avoid leakage 99 | if mode == "train": 100 | ind = -1 101 | self.index_month_histmask = [] 102 | self.position_in_month_histmask = [] 103 
| 104 | for i in range(len(self.index_month)): 105 | while True: 106 | ind += 1 107 | if ind == len(self.index_month): 108 | ind = 0 109 | if self.valid_for_histmask[ind] == 1: 110 | self.index_month_histmask.append(self.index_month[ind]) 111 | self.position_in_month_histmask.append( 112 | self.position_in_month[ind] 113 | ) 114 | break 115 | else: # dummy (histmask is only used for training) 116 | self.index_month_histmask = self.index_month 117 | self.position_in_month_histmask = self.position_in_month 118 | 119 | def __getitem__(self, org_index): 120 | """ 121 | Returns a sample from the dataset at the specified index. 122 | 123 | Parameters: 124 | - org_index (int): The original index of the sample to retrieve. 125 | 126 | Returns: 127 | - dict: A dictionary containing the data sample including observed data, masks, and timepoints, and a mask to avoid double evaluation of overlapping values. 128 | """ 129 | index = self.use_index[org_index] 130 | c_month = self.index_month[index] 131 | c_index = self.position_in_month[index] 132 | hist_month = self.index_month_histmask[index] 133 | hist_index = self.position_in_month_histmask[index] 134 | s = { 135 | "observed_data": self.observed_data[c_month][ 136 | c_index : c_index + self.eval_length 137 | ], 138 | "observed_mask": self.observed_mask[c_month][ 139 | c_index : c_index + self.eval_length 140 | ], 141 | "gt_mask": self.gt_mask[c_month][ 142 | c_index : c_index + self.eval_length 143 | ], 144 | "hist_mask": self.observed_mask[hist_month][ 145 | hist_index : hist_index + self.eval_length 146 | ], 147 | "timepoints": np.arange(self.eval_length), 148 | "cut_length": self.cut_length[org_index], 149 | } 150 | 151 | return s 152 | 153 | def __len__(self): 154 | """ 155 | Returns the total number of samples in the dataset. 156 | 157 | Returns: 158 | - int: Total number of samples. 159 | """ 160 | return len(self.use_index) 161 | 162 | 163 | def get_dataloader_pm25(batch_size, device, validindex=0): 164 | """ 165 | Prepares DataLoader objects for the PM2.5 dataset for training, validation, and testing. Also, returns 166 | normalization parameters (scaler and mean_scaler) to normalize the data. 167 | 168 | Parameters: 169 | - batch_size (int): Batch size for the DataLoader. 170 | - device (torch.device): The device on which the tensors will be loaded. 171 | - validindex (int): Index used to select the validation month in 'train' and 'valid' modes. 172 | 173 | Returns: 174 | - Tuple: Contains DataLoader objects for training, validation, and testing, along with normalization parameters (scaler, mean_scaler). 
175 | """ 176 | dataset = PM25_Dataset(mode="train", validindex=validindex) 177 | train_loader = DataLoader( 178 | dataset, batch_size=batch_size, num_workers=1, shuffle=True 179 | ) 180 | dataset_test = PM25_Dataset(mode="test", validindex=validindex) 181 | test_loader = DataLoader( 182 | dataset_test, batch_size=batch_size, num_workers=1, shuffle=False 183 | ) 184 | dataset_valid = PM25_Dataset(mode="valid", validindex=validindex) 185 | valid_loader = DataLoader( 186 | dataset_valid, batch_size=batch_size, num_workers=1, shuffle=False 187 | ) 188 | 189 | scaler = torch.from_numpy(dataset.train_std).to(device).float() 190 | mean_scaler = torch.from_numpy(dataset.train_mean).to(device).float() 191 | 192 | return train_loader, valid_loader, test_loader, scaler, mean_scaler 193 | 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
This repository will not be maintained any further, and issues and pull requests may be ignored. For an up-to-date codebase, issues, and pull requests, please continue to the new repository.
4 |
12 |
13 | **Self-Supervised Learning of Time Series Representation via Diffusion Process and Imputation-Interpolation-Forecasting Mask**
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
25 | Usage •
26 | Examples •
27 | Checkpoints •
28 | Processed Datasets •
29 | Citation
30 |
31 |
46 |
47 |
48 | ## Usage
49 |
50 | We recommend starting by installing the dependencies in a virtual environment.
51 | ```
52 | conda create --name tsde python=3.11 -y
53 | conda activate tsde
54 | pip install -r requirements.txt
55 | ```
56 |
57 | ### Datasets
58 | Download public datasets used in our experiments:
59 |
60 | ```
61 | python src/utils/download_data.py [dataset-name]
62 | ```
63 | Options for [dataset-name]: physio, pm25, electricity, solar, traffic, taxi, wiki, msl, smd, smap, swat, and psm.
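For example, to download the PhysioNet dataset (any of the dataset names above can be substituted):
```
python src/utils/download_data.py physio
```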
64 |
65 | ### Imputation
66 |
67 | To run the imputation experiments on the PhysioNet dataset:
68 | ```
69 | python src/experiments/train_test_imputation.py --device [device] --dataset PhysioNet --physionet_classification True --testmissingratio [test_missing_ratio]
70 | ```
71 | In our experiments, we set [test_missing_ratio] to 0.1, 0.5, and 0.9.
72 |
73 | To run the imputation experiments on the PM2.5 dataset, first set the training epochs to 1500 and the finetuning epochs to 100 in src/config/base.yaml, then run the following command (a sketch of the corresponding config section follows the command):
74 | ```
75 | python src/experiments/train_test_imputation.py --device [device] --dataset Pm25
76 | ```
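A minimal sketch of the relevant part of src/config/base.yaml, assuming the epoch counts live under the `train` and `finetuning` keys read by the training scripts (all other keys in the file stay as they are):
```
train:
  epochs: 1500
finetuning:
  epochs: 100
```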
77 |
78 | ### Interpolation
79 | To run the interpolation experiments on the PhysioNet dataset:
80 | ```
81 | python src/experiments/train_test_interpolation.py --device [device] --dataset PhysioNet --physionet_classification True --testmissingratio [test_missing_ratio]
82 | ```
83 | In our experiments, we set [test_missing_ratio] to 0.1, 0.5, and 0.9.
84 |
85 | ### Forecasting
86 | Please first set the number of pretraining and finetuning epochs for each dataset in src/config/base_forecasting.yaml, and set the number of features used for sub-sampling during training in the TSDE_Forecasting model in src/tsde/main_model.py.
87 | Run the following command:
88 | ```
89 | python src/experiments/train_test_forecasting.py --dataset [dataset-name] --device [device] --sample_feat
90 | ```
91 | Remove the --sample_feat flag to disable feature sub-sampling during training.
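For example, to train and evaluate on the Traffic dataset with feature sub-sampling enabled (the device name below is a placeholder):
```
python src/experiments/train_test_forecasting.py --dataset Traffic --device cuda:0 --sample_feat
```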
92 | ### Anomaly Detection
93 | Please first set the number of features and the number of pretraining and finetuning epochs for each dataset in src/config/base_ad.yaml.
94 | Run the following command:
95 | ```
96 | python src/experiments/train_test_anomaly_detection.py --dataset [dataset-name] --device [device] --seed [seed] --anomaly_ratio [anomaly_ratio]
97 | ```
98 |
99 | The values of [dataset-name], [seed] and [anomaly_ratio] used in our experiments are available in our paper.
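For illustration only, a run on the MSL dataset might look as follows; the exact dataset spelling, seed, and anomaly ratio shown here are placeholders rather than the settings reported in the paper:
```
python src/experiments/train_test_anomaly_detection.py --dataset MSL --device cuda:0 --seed 1 --anomaly_ratio 1
```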
100 |
101 | ### Classification on PhysioNet
102 | Run the following command:
103 | ```
104 | python src/experiments/train_test_classification.py --seed [seed] --device [device] --testmissingratio [test_missing_ratio]
105 | ```
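For example, with a single seed and a 10% test missing ratio (the seed and device values are placeholders):
```
python src/experiments/train_test_classification.py --seed 1 --device cuda:0 --testmissingratio 0.1
```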
106 | ### Benchmarking TSDE against Time Series Library models
107 | 1. Download the Electricity dataset following the guidelines available [here](https://github.com/thuml/Time-Series-Library), and place the dataset folder in the root directory.
108 | 2. Run the following command:
109 | ```
110 | python src/experiments/train_test_tslib_forecasting.py --device [device] --pred_length [pred_length] --hist_length [hist_length]
111 | ```
112 | The values of [pred_length] and [hist_length] used in our experiments are available in our paper.
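For instance, using the default window sizes of the Electricity dataloader (a 96-step prediction horizon with a 192-step history; the device value is a placeholder and the paper's settings may differ):
```
python src/experiments/train_test_tslib_forecasting.py --device cuda:0 --pred_length 96 --hist_length 192
```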
113 |
114 | ## Examples
115 | #### Examples of imputation, interpolation and forecasting
116 |
117 |
118 | #### Example of clustering
119 |
120 |
121 | #### Example of embedding visualization
122 |
123 |
124 | ## Checkpoints
125 | To run the evaluation using a specific checkpoint, follow the instructions below. Ensure that your environment is set up correctly and that the datasets have been downloaded first.
126 | ### Imputation
127 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Imputation.zip). Place the content of this folder under [root_dir]/save.
128 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should exclude '[root_dir]/save' and 'model.pth'.
129 | ```
130 | python src/experiments/train_test_imputation.py --device [device] --dataset PhysioNet --physionet_classification True --testmissingratio 0.1 --modelfolder [path_to_checkpoint_folder] --run [run_number]
131 | ```
132 | ```
133 | python src/experiments/train_test_imputation.py --device [device] --dataset Pm25 --modelfolder [path_to_checkpoint_folder] --run [run_number]
134 | ```
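For reference, one way to fetch and place the imputation checkpoints from step 1 on a Unix-like system (assuming wget and unzip are available; verify that the extracted folders end up directly under [root_dir]/save, as the archive layout may differ):
```
wget https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Imputation.zip
unzip Imputation.zip -d save/
```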
135 |
136 | ### Interpolation
137 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Interpolation.zip). Place the content of this folder under [root_dir]/save.
138 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should exclude '[root_dir]/save' and 'model.pth'.
139 | ```
140 | python src/experiments/train_test_interpolation.py --device [device] --dataset PhysioNet --physionet_classification True --testmissingratio 0.1 --modelfolder [path_to_checkpoint_folder] --run [run_number]
141 | ```
142 | ### Forecasting
143 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Forecasting.zip). Place the content of this folder under [root_dir]/save.
144 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should exclude '[root_dir]/save' and 'model.pth'.
145 | ```
146 | python src/experiments/train_test_forecasting.py --device [device] --dataset [dataset-name] --modelfolder [path_to_checkpoint_folder] --run [run_number]
147 | ```
148 |
149 | ### Anomaly Detection
150 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Anomaly_Detection.zip). Place the content of this folder under [root_dir]/save.
151 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should exclude '[root_dir]/save' and 'model.pth'.
152 | ```
153 | python src/experiments/train_test_anomaly_detection.py --device [device] --dataset [dataset-name] --modelfolder [path_to_checkpoint_folder] --run [run_number] --disable_finetune
154 | ```
155 |
156 | ### Classification
157 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Classification.zip). Place the content of this folder under [root_dir]/save.
158 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should include '[root_dir]/save' and exclude 'model.pth'.
159 | ```
160 | python src/experiments/train_test_classification.py --device [device] --modelfolder [path_to_checkpoint_folder] --run [run_number] --disable_finetune
161 | ```
162 | ## Citation
163 | ```
164 | @article{senane2024tsde,
165 | title={{Self-Supervised Learning of Time Series Representation via Diffusion Process and Imputation-Interpolation-Forecasting Mask}},
166 | author={Senane, Zineb and Cao, Lele and Buchner, Valentin Leonhard and Tashiro, Yusuke and You, Lei and Herman, Pawel and Nordahl, Mats and Tu, Ruibo and von Ehrenheim, Vilhelm},
167 | year={2024},
168 | eprint={2405.05959},
169 | archivePrefix={arXiv},
170 | primaryClass={cs.LG}
171 | }
172 | ```
173 |
--------------------------------------------------------------------------------
/src/utils/download_data.py:
--------------------------------------------------------------------------------
1 | import tarfile
2 | import zipfile
3 | import sys
4 | import os
5 | import wget
6 | import requests
7 | import pandas as pd
8 | import pickle
9 | import gdown
10 |
11 | # Define a root URL for downloading forecasting datasets used in Salinas et al. https://proceedings.neurips.cc/paper/2019/file/0b105cf1504c4e241fcc6d519ea962fb-Paper.pdf
12 | root = "https://raw.githubusercontent.com/mbohlkeschneider/gluon-ts/mv_release/datasets/"
13 |
14 |
15 | def get_confirm_token(response):
16 | """
17 | Extract the confirmation token from the response cookies, which is required for some downloads.
18 |
19 | Parameters:
20 | - response: The response object from a requests session.
21 |
22 | Returns:
23 | - The confirmation token if found; otherwise, None.
24 | """
25 | for key, value in response.cookies.items():
26 | if key.startswith('download_warning'):
27 | return value
28 |
29 | return None
30 |
31 | def save_response_content(response, destination):
32 | """
33 | Saves the content of a response object to a file in chunks, which is useful for large downloads.
34 |
35 | Parameters:
36 | - response: The response object from a requests session.
37 | - destination: The file path where the content will be saved.
38 | """
39 |
40 | CHUNK_SIZE = 32768
41 |
42 | with open(destination, "wb") as f:
43 | for chunk in response.iter_content(CHUNK_SIZE):
44 | if chunk: # filter out keep-alive new chunks
45 | f.write(chunk)
46 |
47 |
48 | def download_file_from_google_drive(url, destination):
49 | """
50 | Downloads a file from Google Drive by handling the confirmation token and saving the response content.
51 |
52 | Parameters:
53 | - url: The Google Drive file URL.
54 | - destination: The file path where the downloaded data will be saved.
55 | """
56 | session = requests.Session()
57 |
58 | response = session.get(url, stream = True)
59 | token = get_confirm_token(response)
60 |
61 | if token:
62 | params = { 'confirm' : token }
63 | response = session.get(url, params = params, stream = True)
64 | save_response_content(response, destination)
65 |
66 | # Ensure the data directory exists
67 | os.makedirs("data/", exist_ok=True)
68 |
69 | # Download and extract datasets based on the command-line argument
70 | if sys.argv[1] == "physio":
71 | # Install PhysioNet dataset
72 |
73 | url = "https://physionet.org/files/challenge-2012/1.0.0/set-a.tar.gz?download"
74 | url_outcomes = "https://physionet.org/files/challenge-2012/1.0.0/Outcomes-a.txt?download"
75 | os.makedirs("data/physio", exist_ok=True)
76 | wget.download(url, out="data/")
77 | wget.download(url_outcomes, out="data/physio")
78 | with tarfile.open("data/set-a.tar.gz", "r:gz") as t:
79 | t.extractall(path="data/physio")
80 | print(f"Downloaded data/physio")
81 | os.remove("data/set-a.tar.gz")
82 |
83 |
84 | elif sys.argv[1] == "pm25":
85 | # Install PM2.5 dataset
86 |
87 | url = "https://www.microsoft.com/en-us/research/uploads/prod/2016/06/STMVL-Release.zip"
88 |
89 | headers = {
90 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
91 | }
92 | filename = "data/STMVL-Release.zip"
93 | with requests.get(url, stream=True, headers=headers) as response:
94 | # Check if the request was successful (status code 200)
95 | if response.status_code == 200:
96 | with open(filename, 'wb') as f:
97 | # Write the content of the file to the local file
98 | f.write(response.content)
99 | print(f"Downloaded {filename}")
100 | else:
101 | print(f"Failed to download the file. Status code: {response.status_code}")
102 | with zipfile.ZipFile(filename) as z:
103 | z.extractall("data/pm25")
104 |
105 | os.remove(filename)
106 |
107 | def create_normalizer_pm25():
108 | """
109 | Normalizes the PM2.5 dataset by calculating and saving the mean and standard deviation, excluding test months.
110 |
111 |
112 | It saves these values for use in normalizing the training and testing data.
113 | """
114 | df = pd.read_csv(
115 | "./data/pm25/Code/STMVL/SampleData/pm25_ground.txt",
116 | index_col="datetime", # Use the datetime column as the index
117 | parse_dates=True,
118 | )
119 |
120 | # Define the months to exclude from the normalization process (test months)
121 | test_month = [3, 6, 9, 12]
122 | # Exclude the test months from the dataset
123 | for i in test_month:
124 | df = df[df.index.month != i]
125 | # Calculate the mean and standard deviation for the dataset after excluding the test months
126 | mean = df.describe().loc["mean"].values
127 | std = df.describe().loc["std"].values
128 |
129 | # Save the mean and standard deviation to a file using pickle
130 | path = "./data/pm25/pm25_meanstd.pk"
131 | with open(path, "wb") as f:
132 | pickle.dump([mean, std], f)
133 | create_normalizer_pm25()
134 |
135 | elif sys.argv[1] == "electricity":
136 | # Install Electricity dataset
137 |
138 | url = root + "electricity_nips.tar.gz?download"
139 | wget.download(url, out="data")
140 | with tarfile.open("data/electricity_nips.tar.gz", "r:gz") as t:
141 | t.extractall(path="data/electricity")
142 | print(f"Downloaded data/electricity")
143 | os.remove("data/electricity_nips.tar.gz")
144 |
145 | elif sys.argv[1] == "solar":
146 | # Install Solar dataset
147 |
148 | url = root + "solar_nips.tar.gz?download"
149 | wget.download(url, out="data")
150 | with tarfile.open("data/solar_nips.tar.gz", "r:gz") as t:
151 | t.extractall(path="data/solar")
152 | print(f"Downloaded data/solar")
153 | os.remove("data/solar_nips.tar.gz")
154 | os.rename("data/solar/solar_nips/train/train.json", "data/solar/solar_nips/train/data.json")
155 | os.rename("data/solar/solar_nips/test/test.json", "data/solar/solar_nips/test/data.json")
156 |
157 | elif sys.argv[1] == "traffic":
158 | # Install Traffic dataset
159 |
160 | url = root + "traffic_nips.tar.gz?download"
161 | wget.download(url, out="data")
162 | with tarfile.open("data/traffic_nips.tar.gz", "r:gz") as t:
163 | t.extractall(path="data/traffic")
164 | print(f"Downloaded data/traffic")
165 | os.remove("data/traffic_nips.tar.gz")
166 |
167 | elif sys.argv[1] == "taxi":
168 | # Install Taxi dataset
169 |
170 | url = root + "taxi_30min.tar.gz?download"
171 | wget.download(url, out="data")
172 | with tarfile.open("data/taxi_30min.tar.gz", "r:gz") as t:
173 | t.extractall(path="data/taxi")
174 | print(f"Downloaded data/taxi")
175 | os.remove("data/taxi_30min.tar.gz")
176 | os.rename("data/taxi/taxi_30min/", "data/taxi/taxi_nips")
177 | os.rename("data/taxi/taxi_nips/train/train.json", "data/taxi/taxi_nips/train/data.json")
178 | os.rename("data/taxi/taxi_nips/test/test.json", "data/taxi/taxi_nips/test/data.json")
179 |
180 |
181 | elif sys.argv[1] == "wiki":
182 | # Install Wiki dataset
183 |
184 | url = "https://github.com/awslabs/gluonts/raw/1553651ca1fca63a16e012b8927bd9ce72b8e79e/datasets/wiki-rolling_nips.tar.gz"
185 | wget.download(url, out="data")
186 | with tarfile.open("data/wiki-rolling_nips.tar.gz", "r:gz") as t:
187 | t.extractall(path="data/wiki")
188 | print(f"Downloaded data/wiki")
189 | os.remove("data/wiki-rolling_nips.tar.gz")
190 | os.rename("data/wiki/wiki-rolling_nips/", "data/wiki/wiki_nips")
191 | os.rename("data/wiki/wiki_nips/train/train.json", "data/wiki/wiki_nips/train/data.json")
192 | os.rename("data/wiki/wiki_nips/test/test.json", "data/wiki/wiki_nips/test/data.json")
193 |
194 | elif sys.argv[1] == "msl":
195 | # Install MSL dataset
196 |
197 | url = 'https://drive.google.com/uc?id=14STjpszyi6D0B7BUHZ1L4GLUkhhPXE0G'
198 | gdown.download(url, 'MSL.zip', quiet=False)
199 | with zipfile.ZipFile('MSL.zip') as z:
200 | z.extractall("data/anomaly_detection")
201 | os.remove('MSL.zip')
202 | print(f"Downloaded MSL dataset")
203 |
204 | elif sys.argv[1] == "psm":
205 | # Install PSM dataset
206 |
207 | url = "https://drive.google.com/uc?id=14gCVQRciS2hs2SAjXpqioxE4CUzaYkhb"
208 | gdown.download(url, 'PSM.zip', quiet=False)
209 | with zipfile.ZipFile('PSM.zip') as z:
210 | z.extractall("data/anomaly_detection")
211 | os.remove('PSM.zip')
212 | print(f"Downloaded PSM dataset")
213 |
214 | elif sys.argv[1] == "smap":
215 | # Install SMAP dataset
216 |
217 | url = "https://drive.google.com/uc?id=1kxiTMOouw1p-yJMkb_Q_CGMjakVNtg3X"
218 | gdown.download(url, 'SMAP.zip', quiet=False)
219 | with zipfile.ZipFile('SMAP.zip') as z:
220 | z.extractall("data/anomaly_detection")
221 | os.remove('SMAP.zip')
222 | print(f"Downloaded SMAP dataset")
223 |
224 |
225 | elif sys.argv[1] == "smd":
226 | # Install SMD dataset
227 |
228 | url = "https://drive.google.com/uc?id=1BgjQ7_2uqRrZ789Pijtpid5xpLTniywu"
229 | gdown.download(url, 'SMD.zip', quiet=False)
230 | with zipfile.ZipFile('SMD.zip') as z:
231 | z.extractall("data/anomaly_detection")
232 | os.remove('SMD.zip')
233 | print(f"Downloaded SMD dataset")
234 |
235 |
236 | elif sys.argv[1] == "swat":
237 | # Install SWaT dataset
238 |
239 | url = "https://drive.google.com/uc?id=1eRKQwJhqmUD4LkWnqNy1cdIz3W_y6EtW"
240 | gdown.download(url, 'SWaT.zip', quiet=False)
241 | with zipfile.ZipFile('SWaT.zip') as z:
242 | z.extractall("data/anomaly_detection")
243 | os.remove('SWaT.zip')
244 | print(f"Downloaded SWaT dataset")
245 |
246 |
247 |
--------------------------------------------------------------------------------
/src/utils/masking_strategies.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 |
5 |
6 |
7 | def imputation_mask_batch(observed_mask):
8 | """
9 | Generates a batch of masks for the imputation task, where certain observations are randomly masked based on a
10 | sample-specific ratio.
11 |
12 | Parameters:
13 | - observed_mask (Tensor): A tensor indicating observed values (1 for observed, 0 for missing).
14 |
15 | Returns:
16 | - Tensor: A mask tensor for imputation with the same shape as `observed_mask`.
17 | """
18 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values
19 | rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
20 | min_value, max_value = 0.1, 0.9
21 | for i in range(len(observed_mask)):
22 | sample_ratio = min_value + (max_value - min_value)*np.random.rand() # missing ratio ## at random
23 | num_observed = observed_mask[i].sum().item()
24 | num_masked = round(num_observed * sample_ratio)
25 | rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1
26 | cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
27 | return cond_mask
28 |
29 | def interpolation_mask_batch(observed_mask):
30 | """
31 | Generates a batch of masks for the interpolation task by randomly selecting, for each sample, a timestamp to mask across all features.
32 |
33 | Parameters:
34 | - observed_mask (Tensor): A tensor indicating observed values.
35 |
36 | Returns:
37 | - Tensor: A mask tensor for interpolation tasks.
38 | """
39 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values
40 | total_timestamps = observed_mask.shape[2]
41 | timestamps = np.arange(total_timestamps)
42 | for i in range(len(observed_mask)):
43 | mask_timestamp = np.random.choice(
44 | timestamps
45 | )
46 | rand_for_mask[i][:,mask_timestamp] = -1
47 | cond_mask = (rand_for_mask > 0).float()
48 | return cond_mask
49 |
50 |
51 | def forecasting_mask_batch(observed_mask):
52 | """
53 | Generates a batch of masks for the forecasting task by masking out all future values beyond a randomly selected start
54 | point in the sequence (30% timestamps at most).
55 |
56 | Parameters:
57 | - observed_mask (Tensor): A tensor indicating observed values.
58 |
59 | Returns:
60 | - Tensor: A mask tensor for forecasting tasks.
61 | """
62 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values
63 | total_timestamps = observed_mask.shape[2]
64 | start_pred_timestamps = round(total_timestamps/3)
65 | timestamps = np.arange(total_timestamps)[start_pred_timestamps:]
66 | for i in range(len(observed_mask)):
67 | start_forecast_mask = np.random.choice(
68 | timestamps
69 | )
70 | rand_for_mask[i][:,-start_forecast_mask:] = -1
71 | cond_mask = (rand_for_mask > 0).float()
72 | return cond_mask
73 |
74 |
75 | def forecasting_imputation_mask_batch(observed_mask):
76 | """
77 | Generates a batch of masks for the combined forecasting/imputation task by masking out all
78 | future values for a random subset of features beyond a randomly selected start
79 | point in the sequence (30% timestamps at most).
80 |
81 | Parameters:
82 | - observed_mask (Tensor): A tensor indicating observed values.
83 |
84 | Returns:
85 | - Tensor: A mask tensor for forecasting tasks.
86 | """
87 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values
88 | total_timestamps = observed_mask.shape[2]
89 | start_pred_timestamps = round(total_timestamps/3)
90 | timestamps = np.arange(total_timestamps)[start_pred_timestamps:]
91 |
92 | for i in range(len(observed_mask)):
93 | feature_indices = list(np.arange(0, len(rand_for_mask[i])))
94 | n_keep_dims = random.choice([1, 2, 3]) # pick how many dims to keep unmasked
95 | keep_dims_idx = random.sample(feature_indices, n_keep_dims) # choose the dims to keep
96 | mask_dims_idx = [j for j in feature_indices if j not in keep_dims_idx] # choose the dims to mask
97 | start_forecast_mask = np.random.choice(
98 | timestamps
99 | )
100 | rand_for_mask[i][mask_dims_idx, -start_forecast_mask:] = -1
101 | cond_mask = (rand_for_mask > 0).float()
102 | return cond_mask
103 |
104 |
105 | def imputation_mask_sample(observed_mask):
106 | """
107 | Generates a mask for imputation for a single sample, similar to `imputation_mask_batch` but for an individual sample.
108 |
109 | Parameters:
110 | - observed_mask (Tensor): A tensor indicating observed values for a single sample.
111 |
112 | Returns:
113 | - Tensor: A mask tensor for imputation for the sample.
114 | """
115 | ## Observed mask of shape KxL
116 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values
117 |
118 | rand_for_mask = rand_for_mask.reshape(-1)
119 | min_value, max_value = 0.1, 0.9
120 | sample_ratio = min_value + (max_value - min_value)*np.random.rand()
121 | num_observed = observed_mask.sum().item()
122 | num_masked = round(num_observed * sample_ratio)
123 | rand_for_mask[rand_for_mask.topk(num_masked).indices] = -1
124 | cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float()
125 | return cond_mask
126 |
127 | def interpolation_mask_sample(observed_mask):
128 | """
129 | Generates a mask for interpolation for a single sample by randomly selecting a timestamp to mask.
130 |
131 | Parameters:
132 | - observed_mask (Tensor): A tensor indicating observed values for a single sample.
133 |
134 | Returns:
135 | - Tensor: A mask tensor for interpolation for the sample.
136 | """
137 | ## Observed mask of shape KxL
138 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values
139 | total_timestamps = observed_mask.shape[1]
140 | timestamps = np.arange(total_timestamps)
141 |
142 | mask_timestamp = np.random.choice(
143 | timestamps
144 | )
145 | rand_for_mask[:,mask_timestamp] = -1
146 | cond_mask = (rand_for_mask > 0).float()
147 | return cond_mask
148 |
149 |
150 | def forecasting_mask_sample(observed_mask):
151 | """
152 | Generates a mask for forecasting for a single sample by masking out all future values beyond a selected timestamp.
153 |
154 | Parameters:
155 | - observed_mask (Tensor): A tensor indicating observed values for a single sample.
156 |
157 | Returns:
158 | - Tensor: A mask tensor for forecasting for the sample.
159 | """
160 | ## Observed mask of shape KxL
161 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values
162 | total_timestamps = observed_mask.shape[1]
163 |
164 | start_pred_timestamps = round(total_timestamps/3)
165 | timestamps = np.arange(total_timestamps)[-start_pred_timestamps:]
166 |
167 | start_forecast_mask = np.random.choice(
168 | timestamps
169 | )
170 | rand_for_mask[:,start_forecast_mask:] = -1
171 | cond_mask = (rand_for_mask > 0).float()
172 |
173 | return cond_mask
174 |
175 |
176 |
177 | def get_mask_equal_p_sample(observed_mask):
178 | """
179 | IIF mix masking strategy.
180 | Generates masks for a batch of samples where each sample has an equal probability of being assigned one of the
181 | three mask types: imputation, interpolation, or forecasting.
182 |
183 | Parameters:
184 | - observed_mask (Tensor): A tensor indicating observed values for a batch of samples.
185 |
186 | Returns:
187 | - Tensor: A batch of masks with a mix of the three types.
188 | """
189 | B, K, L = observed_mask.shape
190 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask
191 | for i in range(B):
192 |
193 | threshold = 1/3
194 |
195 | imp_mask = imputation_mask_sample(observed_mask[i])
196 | p = np.random.rand() # uniform draw used to select the mask type at random
197 |
198 | if p