├── figures
│   ├── tsde.jpg
│   ├── tsde.pdf
│   ├── clustering.jpg
│   ├── embed_viz.png
│   ├── tsde_logo.png
│   ├── pred_example.jpg
│   └── forecasting_pred.pdf
├── requirements.txt
├── CITATION.md
├── LICENSE
├── src
│   ├── config
│   │   ├── base_ad.yaml
│   │   ├── base_forecasting.yaml
│   │   ├── base_classification.yaml
│   │   └── base.yaml
│   ├── base
│   │   ├── diffEmbedding.py
│   │   ├── denoisingNetwork.py
│   │   └── mtsEmbedding.py
│   ├── experiments
│   │   ├── train_test_anomaly_detection.py
│   │   ├── train_test_tslib_forecasting.py
│   │   ├── train_test_interpolation.py
│   │   ├── train_test_imputation.py
│   │   ├── train_test_classification.py
│   │   └── train_test_forecasting.py
│   ├── utils
│   │   ├── metrics.py
│   │   ├── download_data.py
│   │   ├── masking_strategies.py
│   │   └── utils.py
│   ├── data_loader
│   │   ├── elec_tslib_dataloader.py
│   │   ├── pm25_dataloader.py
│   │   ├── physio_dataloader.py
│   │   ├── forecasting_dataloader.py
│   │   └── anomaly_detection_dataloader.py
│   └── tsde
│       └── main_model.py
├── .gitignore
└── README.md
--------------------------------------------------------------------------------
/figures/tsde.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/tsde.jpg
--------------------------------------------------------------------------------
/figures/tsde.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/tsde.pdf
--------------------------------------------------------------------------------
/figures/clustering.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/clustering.jpg
--------------------------------------------------------------------------------
/figures/embed_viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/embed_viz.png
--------------------------------------------------------------------------------
/figures/tsde_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/tsde_logo.png
--------------------------------------------------------------------------------
/figures/pred_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/pred_example.jpg
--------------------------------------------------------------------------------
/figures/forecasting_pred.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EQTPartners/TSDE/HEAD/figures/forecasting_pred.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | pandas
4 | requests
5 | scikit-learn
6 | scipy
7 | tqdm
8 | wget # Note: 'wget' package might be named differently in Conda
9 | --extra-index-url https://download.pytorch.org/whl/cu118 ## if CUDA is available; if not, comment this line out to use the CPU build
10 | torch==2.2.1+cu118
11 | pyyaml
12 | gdown
--------------------------------------------------------------------------------
/CITATION.md:
--------------------------------------------------------------------------------
1 | # Citation
2 | 
3 | If you use or refer to this repository in your research, please cite our paper:
4 | 
5 | ### BibTeX
6 | ```bibtex
7 | @inproceedings{senane2024tsde,
8 |
title={{Self-Supervised Learning of Time Series Representation via Diffusion Process and Imputation-Interpolation-Forecasting Mask}}, 9 | author={Senane, Zineb and Cao, Lele and Buchner, Valentin Leonhard and Tashiro, Yusuke and You, Lei and Herman, Pawel and Nordahl, Mats and Tu, Ruibo and von Ehrenheim, Vilhelm}, 10 | booktitle={to appear In KDD 24: Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining.}, 11 | year={2024} 12 | } 13 | ``` 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 EQT Partners AB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/src/config/base_ad.yaml:
--------------------------------------------------------------------------------
1 | #type: args
2 | 
3 | train:
4 |   epochs: 250 ## Pre-training epochs using IIF masking
5 |   batch_size: 32
6 |   lr: 1.0e-3
7 | 
8 | diffusion:
9 |   layers: 4 ## Number of residual layers in the denoising block
10 |   channels: 64 ## Number of channels for projections in the denoising block (residual channels)
11 |   diffusion_embedding_dim: 128 ## Diffusion step embedding dimension
12 |   beta_start: 0.0001 ## minimum noise level in the forward pass
13 |   beta_end: 0.5 ## maximum noise level in the forward pass
14 |   num_steps: 50 ## Total number of diffusion steps
15 |   schedule: "quad" ## Type of noise scheduler
16 | 
17 | model:
18 |   timeemb: 128 ## Time embedding dimension
19 |   featureemb: 16 ## Feature embedding dimension
20 |   mix_masking_strategy: "equal_p" ## Mix masking strategy
21 |   time_strategy: "hawkes" ## Time embedding type
22 | 
23 | embedding:
24 |   timeemb: 128
25 |   featureemb: 16
26 |   num_feat: 51 ## Total number of features in the MTS (K)
27 |   num_timestamps: 100 ## Total number of timestamps in the MTS (L)
28 |   classes: 2
29 |   channels: 16 ## Number of embedding dimension in both temporal and spatial encoders
30 |   nheads: 8 ## Number of heads in the temporal and spatial encoders
31 | 
32 | finetuning:
33 |   epochs: 30 ## Number of finetuning epochs for the downstream task
34 | 
35 | 
--------------------------------------------------------------------------------
/src/config/base_forecasting.yaml:
--------------------------------------------------------------------------------
1 | #type: args
2 | 
3 | train:
4 |   epochs: 400 ## Pre-training epochs using IIF masking
5 |   batch_size: 8
6 |   lr: 1.0e-3
7 | 
8 | diffusion:
9 |   layers: 4 ## Number of residual layers in the denoising block
10 |   channels: 64 ## Number of channels for projections in the denoising block (residual channels)
11 |   diffusion_embedding_dim: 128 ## Diffusion step embedding dimension
12 |   beta_start: 0.0001 ## minimum noise level in the forward pass
13 |   beta_end: 0.5 ## maximum noise level in the forward pass
14 |   num_steps: 50 ## Total number of diffusion steps
15 |   schedule: "quad" ## Type of noise scheduler
16 | 
17 | model:
18 |   timeemb: 128 ## Time embedding dimension
19 |   featureemb: 16 ## Feature embedding dimension
20 |   mix_masking_strategy: "equal_p" ## Mix masking strategy
21 |   time_strategy: "hawkes" ## Time embedding type
22 | 
23 | embedding:
24 |   timeemb: 128
25 |   featureemb: 16
26 |   num_feat: 35 ## Total number of features in the MTS (K)
27 |   num_timestamps: 48 ## Total number of timestamps in the MTS (L)
28 |   classes: 1
29 |   channels: 16 ## Number of embedding dimension in both temporal and spatial encoders
30 |   nheads: 8 ## Number of heads in the temporal and spatial encoders
31 | 
32 | finetuning:
33 |   epochs: 20 ## Number of finetuning epochs for the downstream task
34 | 
--------------------------------------------------------------------------------
/src/config/base_classification.yaml:
--------------------------------------------------------------------------------
1 | #type: args
2 | 
3 | train:
4 |   epochs: 500 ## Pre-training epochs using IIF masking
5 |   batch_size: 4
6 |   lr: 1.0e-3
7 | 
8 | diffusion:
9 |   layers: 4 ## Number of residual layers in the denoising block
10 |   channels: 64 ## Number of channels for projections in the denoising block (residual channels)
11 |   diffusion_embedding_dim: 128 ## Diffusion step embedding dimension
12 |   beta_start: 0.0001 ## minimum noise level in the forward pass
13 |   beta_end: 0.5 ## maximum noise level in the forward pass
14 |   num_steps: 50 ## Total number of diffusion steps
15 |   schedule: "quad" ## Type of noise scheduler
16 | 
17 | model:
18 |   timeemb: 128 ## Time embedding dimension
19 |   featureemb: 16 ## Feature embedding dimension
20 |   mix_masking_strategy: "equal_p" ## Mix masking strategy
21 |   time_strategy: "hawkes" ## Time embedding type
22 | 
23 | embedding:
24 |   timeemb: 128
25 |   featureemb: 16
26 |   num_feat: 35 ## Total number of features in the MTS (K)
27 |   num_timestamps: 48 ## Total number of timestamps in the MTS (L)
28 |   classes: 2 ## Number of classes
29 |   channels: 16 ## Number of embedding dimension in both temporal and spatial encoders
30 |   nheads: 8 ## Number of heads in the temporal and spatial encoders
31 | 
32 | finetuning:
33 |   epochs: 50 ## Number of finetuning epochs for the downstream task
34 | 
35 | 
--------------------------------------------------------------------------------
/src/config/base.yaml:
--------------------------------------------------------------------------------
1 | #type: args
2 | 
3 | train:
4 |   epochs: 1500 ## Pre-training epochs using IIF masking
5 |   batch_size: 16
6 |   lr: 1.0e-3
7 | 
8 | diffusion:
9 |   layers: 4 ## Number of residual layers in the denoising block
10 |   channels: 64 ## Number of channels for projections in the denoising block (residual channels)
11 |   diffusion_embedding_dim: 128 ## Diffusion step embedding dimension
12 |   beta_start: 0.0001 ## minimum noise level in the forward pass
13 |   beta_end: 0.5 ## maximum noise level in the forward pass
14 |   num_steps: 50 ## Total number of diffusion steps
15 |   schedule: "quad" ## Type of noise scheduler
16 | 
17 | model:
18 |   timeemb: 128 ## Time embedding dimension
19 |   featureemb: 16 ## Feature embedding dimension
20 |   mix_masking_strategy: "equal_p" ## Mix masking strategy
21 |   time_strategy: "hawkes" ## Time embedding type
22 | 
23 | embedding:
24 |   timeemb: 128 ## Time embedding dimension, needed as a parameter for the embedding block, same as the one under model
25 |   featureemb: 16 ## Feature embedding dimension, needed as a parameter for the embedding block, same as the one under model
26 |   num_feat: 35 ## Total number of features in the MTS (K)
27 |   num_timestamps: 48 ## Total number of timestamps in the MTS (L)
28 |   classes: 1
29 |   channels: 16 ## Number of embedding dimension in both temporal and spatial encoders
30 |   nheads: 8 ## Number of heads in the temporal and spatial encoders
31 | 
32 | finetuning:
33 |   epochs: 100 ## Number of finetuning epochs for the downstream task
34 | 
--------------------------------------------------------------------------------
/src/base/diffEmbedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class DiffusionEmbedding(nn.Module):
6 |     """
7 |     A neural network module for creating embeddings for diffusion steps. This module is designed to encode the diffusion
8 |     steps into a continuous space for the reverse pass of the DDPM (denoising block in TSDE).
9 | 
10 |     Parameters:
11 |     - num_steps (int): The number of diffusion steps or time steps to be encoded (T=50).
12 |     - embedding_dim (int, optional): The dimensionality of the embedding space (we set it to 128).
13 |     - projection_dim (int, optional): The dimensionality of the projected embedding space.
If not specified, 14 | it defaults to the same as `embedding_dim`. 15 | 16 | The embedding for a given diffusion step is produced by first generating a sinusoidal embedding of the step, 17 | followed by projecting this embedding through two linear layers with SiLU (Sigmoid Linear Unit) and ReLU 18 | activations, respectively. 19 | """ 20 | 21 | def __init__(self, num_steps, embedding_dim=128, projection_dim=None): 22 | super().__init__() 23 | if projection_dim is None: 24 | projection_dim = embedding_dim 25 | self.register_buffer( 26 | "embedding", 27 | self._build_embedding(num_steps, embedding_dim / 2), 28 | persistent=False, 29 | ) 30 | self.projection1 = nn.Linear(embedding_dim, projection_dim) 31 | self.projection2 = nn.Linear(projection_dim, projection_dim) 32 | 33 | def forward(self, diffusion_step): 34 | """ 35 | Defines the forward pass for projecting the diffusion embedding of a specific diffusion step t. 36 | 37 | Parameters: 38 | - diffusion_step: An integer indicating the diffusion step for which embeddings are generated. 39 | 40 | Returns: 41 | - Tensor: The projected embedding for the given diffusion step. 42 | """ 43 | x = self.embedding[diffusion_step] 44 | x = self.projection1(x) 45 | x = F.silu(x) 46 | x = self.projection2(x) 47 | x = F.relu(x) 48 | return x 49 | 50 | def _build_embedding(self, num_steps, dim=64): 51 | """ 52 | Builds the sinusoidal embedding table for diffusion steps as in CSDI (https://arxiv.org/pdf/2107.03502.pdf). 53 | 54 | Parameters: 55 | - num_steps (int): The number of diffusion steps to encode (T=50). 56 | - dim (int): The dimensionality of the sinusoidal embedding before doubling (due to sin and cos). 57 | 58 | Returns: 59 | - Tensor: A tensor of shape (num_steps, embedding_dim) containing the sinusoidal embeddings of all diffusion steps. 
60 | """ 61 | steps = torch.arange(num_steps).unsqueeze(1) # (T,1) 62 | frequencies = 10.0 ** (torch.arange(dim) / (dim - 1) * 4.0).unsqueeze(0) # (1,dim) 63 | table = steps * frequencies # (T,dim) 64 | table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) # (T,dim*2) 65 | return table -------------------------------------------------------------------------------- /src/experiments/train_test_anomaly_detection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | import random 8 | import numpy as np 9 | import torch.nn as nn 10 | 11 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 12 | sys.path.append(root_dir) 13 | 14 | from data_loader.anomaly_detection_dataloader import anomaly_detection_dataloader 15 | from tsde.main_model import TSDE_AD 16 | from utils.utils import train, finetune, evaluate_finetuning, gsutil_cp, set_seed 17 | 18 | torch.backends.cudnn.enabled = False 19 | 20 | parser = argparse.ArgumentParser(description="TSDE-Anomaly Detection") 21 | parser.add_argument("--config", type=str, default="base_ad.yaml") 22 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 23 | parser.add_argument("--seed", type=int, default=1) 24 | parser.add_argument("--disable_finetune", action="store_true") 25 | 26 | 27 | parser.add_argument("--dataset", type=str, default='SMAP') 28 | parser.add_argument("--modelfolder", type=str, default="") 29 | parser.add_argument("--run", type=int, default=1) 30 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal_p or probabilistic_layering)") 31 | parser.add_argument("--anomaly_ratio", type=float, default=1, help="Anomaly ratio") 32 | args = parser.parse_args() 33 | print(args) 34 | 35 | path = "src/config/" + args.config 36 | with open(path, "r") as f: 37 | config = yaml.safe_load(f) 38 | 39 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 40 | 41 | print(json.dumps(config, indent=4)) 42 | 43 | set_seed(args.seed) 44 | 45 | 46 | foldername = "./save/Anomaly_Detection/" + args.dataset + "/run_" + str(args.run) +"/" 47 | model = TSDE_AD(target_dim = config["embedding"]["num_feat"], config = config, device = args.device).to(args.device) 48 | train_loader, valid_loader, test_loader = anomaly_detection_dataloader(dataset_name = args.dataset, batch_size = config["train"]["batch_size"]) 49 | anomaly_ratio = args.anomaly_ratio 50 | 51 | print('model folder:', foldername) 52 | os.makedirs(foldername, exist_ok=True) 53 | 54 | with open(foldername + "config.json", "w") as f: 55 | json.dump(config, f, indent=4) 56 | 57 | if args.modelfolder == "": 58 | loss_path = foldername + "/losses.txt" 59 | with open(loss_path, "a") as file: 60 | file.write("Pretraining"+"\n") 61 | ## Pre-training 62 | train(model, config["train"], train_loader, foldername=foldername, normalize_for_ad=True) 63 | else: 64 | model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)) 65 | 66 | 67 | if not args.disable_finetune: 68 | for param in model.parameters(): 69 | param.requires_grad = False 70 | for param in model.conv.parameters(): 71 | param.requires_grad = True 72 | 73 | for name, param in model.named_parameters(): 74 | print(f"{name}: {param.requires_grad}") 75 | 76 | finetune(model, config["finetuning"], train_loader, criterion = nn.MSELoss(), foldername=foldername, 
task='anomaly_detection', normalize_for_ad=True) 77 | evaluate_finetuning(model, train_loader, test_loader, anomaly_ratio = anomaly_ratio, foldername=foldername, task='anomaly_detection', normalize_for_ad=True) 78 | -------------------------------------------------------------------------------- /src/experiments/train_test_tslib_forecasting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | 8 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 9 | sys.path.append(root_dir) 10 | 11 | # Import necessary modules from your project 12 | from data_loader.elec_tslib_dataloader import get_dataloader_elec 13 | from tsde.main_model import TSDE_Forecasting 14 | from utils.utils import train, evaluate, gsutil_cp, set_seed 15 | 16 | def main(): 17 | # Command line arguments for configuration, could be expanded as needed 18 | parser = argparse.ArgumentParser(description="Run forecasting model") 19 | parser.add_argument('--config', type=str, default='base_forecasting.yaml', help='Path to configuration yaml') 20 | parser.add_argument('--device', type=str, default='cuda:1', help='Device to run the model on') 21 | parser.add_argument('--nsample', type=int, default=100, help='Number of samples') 22 | parser.add_argument('--hist_length', type=int, default=96, help='History window length') 23 | parser.add_argument('--pred_length', type=int, default=192, help='Prediction window length') 24 | parser.add_argument('--run', type=int, default=100200000, help='Run identifier') 25 | parser.add_argument('--linear', action='store_true', help='Linear mode flag') 26 | parser.add_argument('--sample_feat', action='store_true', help='Sample feature flag') 27 | parser.add_argument('--seed', type=int, default=1, help='Seed for random number generation') 28 | parser.add_argument('--load', type=str, default=None, help='Path to pretrained model') 29 | args = parser.parse_args() 30 | 31 | # Set seed for reproducibility 32 | set_seed(args.seed) 33 | 34 | # Load config 35 | path = "src/config/" + args.config 36 | with open(path, "r") as f: 37 | config = yaml.safe_load(f) 38 | 39 | # Setup model folder path 40 | foldername = f"./save/Forecasting/TSLIB_Elec/n_samples_{args.nsample}_run_{args.run}_linear_{args.linear}_sample_feat_{args.sample_feat}/" 41 | os.makedirs(foldername, exist_ok=True) 42 | 43 | # Save configuration to the model folder 44 | with open(os.path.join(foldername, "config.json"), "w") as f: 45 | json.dump(config, f, indent=4) 46 | 47 | # Model setup 48 | model = TSDE_Forecasting(config, args.device, target_dim=321, sample_feat=args.sample_feat).to(args.device) 49 | train_loader, valid_loader, test_loader, _ = get_dataloader_elec( 50 | pred_length=args.pred_length, 51 | history_length=args.hist_length, 52 | batch_size=config["train"]["batch_size"], 53 | device=args.device, 54 | ) 55 | if args.load is None: 56 | # Start training 57 | train( 58 | model, 59 | config["train"], 60 | train_loader, 61 | valid_loader=valid_loader, 62 | test_loader=test_loader, 63 | foldername=foldername, 64 | nsample=args.nsample, 65 | scaler=1, 66 | mean_scaler=0, 67 | eval_epoch_interval=20000, 68 | ) 69 | else: 70 | model.load_state_dict(torch.load("./save/" + args.load + "/model.pth", map_location=args.device)) 71 | 72 | evaluate(model, test_loader, nsample=args.nsample, foldername=foldername) 73 | if __name__ == "__main__": 74 | main() 75 | 
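
A condensed, illustrative sketch of the pretrain-then-finetune flow implemented by `train_test_anomaly_detection.py` above, with the argument parsing stripped away. It is a sketch only: the save folder, dataset name (`SMAP`, the script's default), device string, and anomaly ratio are placeholder values, while the function calls and their keyword arguments follow the script.

```python
# Condensed sketch of src/experiments/train_test_anomaly_detection.py.
# Assumes it is run from the repository root; folder and dataset names are illustrative.
import os
import sys
import yaml
import torch
import torch.nn as nn

sys.path.append("src")  # the experiment scripts do the equivalent via root_dir and sys.path

from data_loader.anomaly_detection_dataloader import anomaly_detection_dataloader
from tsde.main_model import TSDE_AD
from utils.utils import train, finetune, evaluate_finetuning, set_seed

set_seed(1)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

with open("src/config/base_ad.yaml", "r") as f:
    config = yaml.safe_load(f)

foldername = "./save/Anomaly_Detection/SMAP/run_1/"  # placeholder run folder
os.makedirs(foldername, exist_ok=True)

model = TSDE_AD(target_dim=config["embedding"]["num_feat"], config=config, device=device).to(device)
train_loader, valid_loader, test_loader = anomaly_detection_dataloader(
    dataset_name="SMAP", batch_size=config["train"]["batch_size"]
)

# 1) Self-supervised pretraining with the IIF masking objective.
train(model, config["train"], train_loader, foldername=foldername, normalize_for_ad=True)

# 2) Freeze the pretrained backbone and finetune only the downstream head (model.conv).
for param in model.parameters():
    param.requires_grad = False
for param in model.conv.parameters():
    param.requires_grad = True

finetune(model, config["finetuning"], train_loader, criterion=nn.MSELoss(),
         foldername=foldername, task="anomaly_detection", normalize_for_ad=True)

# 3) Score the test split at the chosen anomaly ratio (the script's default is 1).
evaluate_finetuning(model, train_loader, test_loader, anomaly_ratio=1,
                    foldername=foldername, task="anomaly_detection", normalize_for_ad=True)
```

The same freeze-everything-then-unfreeze-the-head pattern reappears in the classification experiment below, where only `model.mlp` is left trainable during finetuning.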
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # Data 156 | data 157 | 158 | # Results 159 | save 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | -------------------------------------------------------------------------------- /src/experiments/train_test_interpolation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | 8 | 9 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 10 | sys.path.append(root_dir) 11 | 12 | 13 | from data_loader.physio_dataloader import get_dataloader_physio 14 | from tsde.main_model import TSDE_Physio 15 | from utils.utils import train, evaluate, gsutil_cp, set_seed 16 | 17 | torch.backends.cudnn.enabled = False 18 | 19 | parser = argparse.ArgumentParser(description="TSDE-Interpolation") 20 | parser.add_argument("--config", type=str, default="base.yaml") 21 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 22 | parser.add_argument("--seed", type=int, default=1) 23 | 24 | 25 | parser.add_argument("--dataset", type=str, default='PhysioNet') 26 | parser.add_argument("--modelfolder", type=str, default="") 27 | parser.add_argument("--nsample", type=int, default=100) 28 | parser.add_argument("--run", type=int, default=1) 29 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal_p or probabilistic_layering)") 30 | 31 | 32 | ## Args for physio 33 | parser.add_argument("--nfold", type=int, default=0, help="for 5fold test (valid value:[0-4])") 34 | parser.add_argument("--testmissingratio", type=float, default=0.1) 35 | parser.add_argument("--physionet_classification", type=bool, default=False) 36 | 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | path = "src/config/" + args.config 41 | with open(path, "r") as f: 42 | config = yaml.safe_load(f) 43 | 44 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 45 | config["model"]["test_missing_ratio"] = args.testmissingratio 46 | 47 | print(json.dumps(config, indent=4)) 48 | 49 | set_seed(args.seed) 50 | 51 | if args.dataset == "PhysioNet": 52 | foldername = "./save/Interpolation/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_missing_ratio_' + str(args.testmissingratio) +"/" 53 | model = TSDE_Physio(config, args.device).to(args.device) 54 | train_loader, valid_loader, 
test_loader = get_dataloader_physio(seed=args.seed, nfold=args.nfold, batch_size=config["train"]["batch_size"], missing_ratio=config["model"]["test_missing_ratio"], mode='interpolation') 55 | scaler = 1 56 | mean_scaler = 0 57 | mode = 'Interpolation' 58 | else: 59 | print() 60 | 61 | print('model folder:', foldername) 62 | os.makedirs(foldername, exist_ok=True) 63 | 64 | with open(foldername + "config.json", "w") as f: 65 | json.dump(config, f, indent=4) 66 | 67 | if args.modelfolder == "": 68 | loss_path = foldername + "/losses.txt" 69 | with open(loss_path, "a") as file: 70 | file.write("Pretraining"+"\n") 71 | ## Pre-training 72 | train(model, config["train"], train_loader, valid_loader=valid_loader, test_loader=test_loader, foldername=foldername, nsample=args.nsample, 73 | scaler=scaler, mean_scaler=mean_scaler, eval_epoch_interval=500,physionet_classification=args.physionet_classification) 74 | if config["finetuning"]["epochs"]!=0: 75 | print('Finetuning') 76 | ## Fine Tuning 77 | with open(loss_path, "a") as file: 78 | file.write("Finetuning"+"\n") 79 | checkpoint_path = foldername + "model.pth" 80 | model.load_state_dict(torch.load(checkpoint_path)) 81 | config["train"]["epochs"]=config["finetuning"]["epochs"] 82 | train( 83 | model, 84 | config["train"], 85 | train_loader, 86 | valid_loader=valid_loader, 87 | foldername=foldername, 88 | mode = mode, 89 | ) 90 | 91 | else: 92 | model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)) 93 | 94 | evaluate( 95 | model, 96 | test_loader, 97 | nsample=args.nsample, 98 | scaler=scaler, 99 | mean_scaler=mean_scaler, 100 | foldername=foldername, 101 | save_samples = True, 102 | physionet_classification=args.physionet_classification 103 | 104 | ) 105 | 106 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import auc 4 | 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def quantile_loss(target, forecast, q: float, eval_points) -> float: 10 | """ 11 | Calculates the quantile loss for a given forecast and target values based on a specific quantile. 12 | 13 | Parameters: 14 | - target: Torch tensor containing the observed values. 15 | - forecast: Torch tensor containing the predicted values. 16 | - q: Float, representing the quantile for which the loss is calculated. 17 | - eval_points: Torch tensor representing the evaluation points for the calculation. 18 | 19 | Returns: 20 | - Float representing the calculated quantile loss. 21 | """ 22 | return 2 * torch.sum( 23 | torch.abs((forecast - target) * eval_points * ((target <= forecast) * 1.0 - q)) 24 | ) 25 | 26 | 27 | def calc_denominator(target, eval_points): 28 | """ 29 | Calculates the denominator used in the CRPS and CRPS-sum calculation, based on the absolute sum of the target values used for evaluation. 30 | 31 | Parameters: 32 | - target: Torch tensor containing the target values. 33 | - eval_points: Torch tensor representing the evaluation points for the calculation. 34 | 35 | Returns: 36 | - Torch tensor representing the denominator value. 37 | """ 38 | return torch.sum(torch.abs(target * eval_points)) 39 | 40 | 41 | def calc_quantile_CRPS(target, forecast, eval_points, mean_scaler, scaler): 42 | """ 43 | Calculates the CRPS based on quantile loss for multiple quantiles. 
44 | 45 | Parameters: 46 | - target: Torch tensor containing the target values. 47 | - forecast: Torch tensor containing the predicted values. 48 | - eval_points: Torch tensor representing the evaluation points for the calculation. 49 | - mean_scaler: Float, the mean value used for scaling the target and forecast back to their original values. 50 | - scaler: Float, the scale value used for scaling the target and forecast back to their original values. 51 | 52 | Returns: 53 | - Float representing the calculated CRPS. 54 | """ 55 | target = target * scaler + mean_scaler 56 | forecast = forecast * scaler + mean_scaler 57 | 58 | quantiles = np.arange(0.05, 1.0, 0.05) 59 | denom = calc_denominator(target, eval_points) 60 | CRPS = 0 61 | for i in range(len(quantiles)): 62 | q_pred = [] 63 | for j in range(len(forecast)): 64 | q_pred.append(torch.quantile(forecast[j : j + 1], quantiles[i], dim=1)) 65 | q_pred = torch.cat(q_pred, 0) 66 | q_loss = quantile_loss(target, q_pred, quantiles[i], eval_points) 67 | CRPS += q_loss / denom 68 | return CRPS.item() / len(quantiles) 69 | 70 | def calc_quantile_CRPS_sum(target, forecast, eval_points, mean_scaler, scaler): 71 | """ 72 | Calculates the CRPS for the sum of the target and predicted values across all features. 73 | 74 | Parameters: 75 | - target: Torch tensor containing the target values. 76 | - forecast: Torch tensor containing the predicted values. 77 | - eval_points: Torch tensor representing the evaluation points for the calculation. 78 | - mean_scaler: Float, the mean value used for scaling the target and predictions back to their original values. 79 | - scaler: Float, the scale value used for scaling the target and predictions back to their original values. 80 | 81 | Returns: 82 | - Float representing the calculated CRPS-sum. 83 | """ 84 | target = target * scaler + mean_scaler 85 | forecast = forecast * scaler + mean_scaler 86 | target_sum = torch.sum(target, dim = 2).unsqueeze(2) 87 | forecast_sum = torch.sum(forecast, dim = 3).unsqueeze(3) 88 | eval_points_sum = torch.mean(eval_points, dim=2).unsqueeze(2) 89 | 90 | crps_sum = calc_quantile_CRPS(target_sum, forecast_sum, eval_points_sum, 0, 1) 91 | return crps_sum 92 | 93 | 94 | def save_roc_curve(fpr, tpr, foldername): 95 | """ 96 | Generates and saves an ROC curve to the specified file path. 97 | 98 | Parameters: 99 | y_true (array-like): True binary labels. 100 | y_scores (array-like): Scores assigned by the classifier. 101 | file_path (str): Path where the ROC curve image will be saved. 102 | 103 | Returns: 104 | str: The file path where the image was saved. 
105 | """ 106 | 107 | roc_auc = auc(fpr, tpr) 108 | 109 | # Plotting and saving the ROC curve 110 | plt.figure() 111 | plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})') 112 | plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') 113 | plt.xlim([0.0, 1.0]) 114 | plt.ylim([0.0, 1.05]) 115 | plt.xlabel('False Positive Rate') 116 | plt.ylabel('True Positive Rate') 117 | plt.title('Receiver Operating Characteristic') 118 | plt.legend(loc="lower right") 119 | plt.savefig(foldername+'roc.png') 120 | plt.close() -------------------------------------------------------------------------------- /src/experiments/train_test_imputation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | import numpy as np 8 | from multiprocessing import freeze_support 9 | 10 | 11 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 12 | sys.path.append(root_dir) 13 | 14 | from data_loader.pm25_dataloader import get_dataloader_pm25 15 | from data_loader.physio_dataloader import get_dataloader_physio 16 | from tsde.main_model import TSDE_PM25, TSDE_Physio 17 | from utils.utils import train, evaluate, gsutil_cp, set_seed 18 | 19 | torch.backends.cudnn.enabled = False 20 | 21 | parser = argparse.ArgumentParser(description="TSDE-Imputation") 22 | parser.add_argument("--config", type=str, default="base.yaml") 23 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 24 | parser.add_argument("--seed", type=int, default=1) 25 | 26 | parser.add_argument("--dataset", type=str, default='Pm25') 27 | parser.add_argument("--modelfolder", type=str, default="") 28 | parser.add_argument("--nsample", type=int, default=100) 29 | parser.add_argument("--run", type=int, default=0) 30 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal_p or probabilistic_layering)") 31 | 32 | 33 | ## Args for pm25 34 | parser.add_argument("--validationindex", type=int, default=0, help="index of month used for validation (value:[0-7])") 35 | ## Args for physio 36 | parser.add_argument("--nfold", type=int, default=0, help="for 5fold test (valid value:[0-4])") 37 | parser.add_argument("--testmissingratio", type=float, default=0.1) 38 | parser.add_argument("--physionet_classification", type=bool, default=False) 39 | 40 | args = parser.parse_args() 41 | print(args) 42 | 43 | path = "src/config/" + args.config 44 | with open(path, "r") as f: 45 | config = yaml.safe_load(f) 46 | 47 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 48 | config["model"]["test_missing_ratio"] = args.testmissingratio 49 | 50 | print(json.dumps(config, indent=4)) 51 | 52 | set_seed(args.seed) 53 | 54 | 55 | if args.dataset == "Pm25": 56 | foldername = "./save/Imputation/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_validationindex_' + str(args.validationindex) +"/" 57 | model = TSDE_PM25(config, args.device).to(args.device) 58 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_pm25(config["train"]["batch_size"], device=args.device, validindex=args.validationindex) 59 | mode = 'Imputation with pattern' 60 | 61 | elif args.dataset == "PhysioNet": 62 | foldername = "./save/Imputation/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_missing_ratio_' + str(args.testmissingratio) +"/" 63 | model = 
TSDE_Physio(config, args.device).to(args.device) 64 | train_loader, valid_loader, test_loader = get_dataloader_physio(seed=args.seed, nfold=args.nfold, batch_size=config["train"]["batch_size"], missing_ratio=config["model"]["test_missing_ratio"]) 65 | scaler = 1 66 | mean_scaler = 0 67 | mode = 'Imputation' 68 | else: 69 | print() 70 | 71 | print('model folder:', foldername) 72 | os.makedirs(foldername, exist_ok=True) 73 | 74 | with open(foldername + "config.json", "w") as f: 75 | json.dump(config, f, indent=4) 76 | 77 | 78 | 79 | def main(): 80 | if args.modelfolder == "": 81 | loss_path = foldername + "/losses.txt" 82 | with open(loss_path, "a") as file: 83 | file.write("Pretraining"+"\n") 84 | 85 | ## Pre-training 86 | train(model, config["train"], train_loader, valid_loader=valid_loader, test_loader=test_loader, foldername=foldername, nsample=args.nsample, 87 | scaler=scaler, mean_scaler=mean_scaler, eval_epoch_interval=500,physionet_classification=args.physionet_classification) 88 | 89 | if config["finetuning"]["epochs"]!=0: 90 | print('Finetuning') 91 | ## Fine Tuning 92 | with open(loss_path, "a") as file: 93 | file.write("Finetuning"+"\n") 94 | checkpoint_path = foldername + "model.pth" 95 | model.load_state_dict(torch.load(checkpoint_path)) 96 | config["train"]["epochs"]=config["finetuning"]["epochs"] 97 | train( 98 | model, 99 | config["train"], 100 | train_loader, 101 | valid_loader=valid_loader, 102 | foldername=foldername, 103 | mode = mode, 104 | ) 105 | 106 | else: 107 | model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)) 108 | 109 | evaluate( 110 | model, 111 | test_loader, 112 | nsample=args.nsample, 113 | scaler=scaler, 114 | mean_scaler=mean_scaler, 115 | foldername=foldername, 116 | save_samples = True, 117 | physionet_classification=args.physionet_classification 118 | ) 119 | 120 | 121 | 122 | 123 | if __name__ == '__main__': 124 | #freeze_support() # Recommended for Windows if you plan to freeze your script 125 | main() 126 | -------------------------------------------------------------------------------- /src/experiments/train_test_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | import torch.nn as nn 8 | 9 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 10 | sys.path.append(root_dir) 11 | 12 | from data_loader.physio_dataloader import get_dataloader_physio, get_physio_dataloader_for_classification 13 | from tsde.main_model import TSDE_Physio 14 | from utils.utils import train, evaluate, gsutil_cp, set_seed, finetune, evaluate_finetuning 15 | 16 | torch.backends.cudnn.enabled = False 17 | 18 | parser = argparse.ArgumentParser(description="TSDE-Imputation-Classification") 19 | parser.add_argument("--config", type=str, default="base_classification.yaml") 20 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 21 | parser.add_argument("--seed", type=int, default=1) 22 | 23 | parser.add_argument("--modelfolder", type=str, default="") 24 | parser.add_argument("--nsample", type=int, default=100) 25 | parser.add_argument("--run", type=int, default=1) 26 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal p or probabilistic layering)") 27 | parser.add_argument("--disable_finetune", action="store_true") 28 | 29 | ## Args for physio 30 | 
parser.add_argument("--nfold", type=int, default=0, help="for 5fold test (valid value:[0-4])") 31 | parser.add_argument("--testmissingratio", type=float, default=0.1) 32 | parser.add_argument("--physionet_classification", type=bool, default=True) 33 | 34 | args = parser.parse_args() 35 | print(args) 36 | 37 | path = "src/config/" + args.config 38 | with open(path, "r") as f: 39 | config = yaml.safe_load(f) 40 | 41 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 42 | config["model"]["test_missing_ratio"] = args.testmissingratio 43 | 44 | print(json.dumps(config, indent=4)) 45 | 46 | set_seed(args.seed) 47 | 48 | 49 | foldername = "./save/Imputation-Classification/" + 'PhysioNet' + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_missing_ratio_' + str(args.testmissingratio) +"/" 50 | model = TSDE_Physio(config, args.device).to(args.device) 51 | train_loader, valid_loader, test_loader = get_dataloader_physio(seed=args.seed, nfold=args.nfold, batch_size=config["train"]["batch_size"], missing_ratio=config["model"]["test_missing_ratio"]) 52 | scaler = 1 53 | mean_scaler = 0 54 | mode = 'Imputation' 55 | 56 | 57 | print('model folder:', foldername) 58 | os.makedirs(foldername, exist_ok=True) 59 | 60 | with open(foldername + "config.json", "w") as f: 61 | json.dump(config, f, indent=4) 62 | 63 | if args.modelfolder == "": 64 | os.makedirs(foldername+'Pretrained/', exist_ok=True) 65 | loss_path = foldername + "losses.txt" 66 | with open(loss_path, "a") as file: 67 | file.write("Pretraining"+"\n") 68 | ## Pre-training 69 | train(model, config["train"], train_loader, valid_loader=valid_loader, test_loader=test_loader, foldername=foldername+'Pretrained/', nsample=args.nsample, 70 | scaler=scaler, mean_scaler=mean_scaler, eval_epoch_interval=100000,physionet_classification=args.physionet_classification) 71 | 72 | ## Save imputed time series in Train, Validation and Test sets 73 | evaluate( 74 | model, 75 | train_loader, 76 | nsample=args.nsample, 77 | scaler=scaler, 78 | mean_scaler=mean_scaler, 79 | foldername=foldername+'Pretrained/', 80 | save_samples = True, 81 | physionet_classification=True, 82 | set_type = 'Train' 83 | ) 84 | 85 | evaluate( 86 | model, 87 | valid_loader, 88 | nsample=args.nsample, 89 | scaler=scaler, 90 | mean_scaler=mean_scaler, 91 | foldername=foldername+'Pretrained/', 92 | save_samples = True, 93 | physionet_classification=True, 94 | set_type = 'Val' 95 | ) 96 | evaluate( 97 | model, 98 | test_loader, 99 | nsample=args.nsample, 100 | scaler=scaler, 101 | mean_scaler=mean_scaler, 102 | foldername=foldername+'Pretrained/', 103 | save_samples = True, 104 | physionet_classification=True, 105 | set_type = 'Test' 106 | ) 107 | 108 | 109 | ## Prepare dataloaders for classification head finetuning 110 | train_loader_classification, valid_loader_classification, test_loader_classification = get_physio_dataloader_for_classification(filename=foldername+'Pretrained/', batch_size=config["train"]["batch_size"]) 111 | model.load_state_dict(torch.load(foldername + "Pretrained/model.pth", map_location=args.device)) 112 | 113 | 114 | else: 115 | # Load pretrained and dataloaders of imputed MTS 116 | train_loader_classification, valid_loader_classification, test_loader_classification = get_physio_dataloader_for_classification(filename=args.modelfolder+"Pretrained/", batch_size=config["train"]["batch_size"]) 117 | model.load_state_dict(torch.load(args.modelfolder + "Pretrained/model.pth", map_location=args.device)) 118 | 119 | print(args.disable_finetune) 120 
| if not args.disable_finetune: 121 | for name, param in model.named_parameters(): 122 | # Freeze all parameters 123 | param.requires_grad = False 124 | 125 | 126 | for name, param in model.mlp.named_parameters(): 127 | param.requires_grad = True 128 | for name, param in model.named_parameters(): 129 | print(f"{name}: {param.requires_grad}") 130 | 131 | ## Finetune the classifier head 132 | finetune(model, config["finetuning"], train_loader_classification, criterion = nn.CrossEntropyLoss(), foldername=foldername) 133 | else: 134 | model.load_state_dict(torch.load(args.modelfolder + "model.pth", map_location=args.device)) 135 | 136 | 137 | ## Evaluate the classification 138 | evaluate_finetuning(model, train_loader_classification, test_loader_classification, foldername=foldername) 139 | 140 | 141 | -------------------------------------------------------------------------------- /src/base/denoisingNetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from base.diffEmbedding import DiffusionEmbedding 7 | 8 | 9 | def Conv1d_with_init(in_channels, out_channels, kernel_size): 10 | """ 11 | Initializes a 1D convolutional layer with Kaiming normal initialization. 12 | 13 | Parameters: 14 | - in_channels (int): Number of channels in the input signal. 15 | - out_channels (int): Number of channels produced by the convolution. 16 | - kernel_size (int): Size of the convolving kernel. 17 | 18 | Returns: 19 | - nn.Conv1d: A 1D convolutional layer with weights initialized. 20 | """ 21 | layer = nn.Conv1d(in_channels, out_channels, kernel_size) 22 | nn.init.kaiming_normal_(layer.weight) 23 | return layer 24 | 25 | 26 | class diff_Block(nn.Module): 27 | """ 28 | A neural network block that incorporates diffusion embedding, designed for the reverse pass in DDPM. It corresponds to the denoising block in the TSDE architecture. 29 | 30 | Parameters: 31 | - config (dict): Configuration dictionary containing model settings for the denoising block. 32 | """ 33 | def __init__(self, config): 34 | super().__init__() 35 | 36 | self.channels = config["channels"] 37 | 38 | self.diffusion_embedding = DiffusionEmbedding( 39 | num_steps=config["num_steps"], 40 | embedding_dim=config["diffusion_embedding_dim"], 41 | ) 42 | 43 | self.input_projection = Conv1d_with_init(1, self.channels, 1) 44 | self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1) 45 | self.output_projection2 = Conv1d_with_init(self.channels, 1, 1) 46 | nn.init.zeros_(self.output_projection2.weight) 47 | 48 | self.residual_layers = nn.ModuleList( 49 | [ 50 | ResidualBlock( 51 | mts_emb_dim=config["mts_emb_dim"], 52 | channels=self.channels, 53 | diffusion_embedding_dim=config["diffusion_embedding_dim"], 54 | ) 55 | for _ in range(config["layers"]) 56 | ] 57 | ) 58 | 59 | def forward(self, x, mts_emb, diffusion_step): 60 | """ 61 | Forward pass of the denoising block. 62 | 63 | Parameters: 64 | - x (Tensor): The corrupted input MTS to be denoised. 65 | - mts_emb (Tensor): The embedding of the observed part of the MTS. 66 | - diffusion_step (Tensor): The current diffusion step index. 67 | 68 | Returns: 69 | - Tensor: The output tensor of the predicted noise added in x at diffusion_step. 
70 | """ 71 | B, inputdim, K, L = x.shape 72 | 73 | x = x.reshape(B, inputdim, K * L) 74 | x = self.input_projection(x) ## First Convolution before fedding the data to the 75 | x = F.relu(x) ## residual block 76 | x = x.reshape(B, self.channels, K, L) 77 | 78 | diffusion_emb = self.diffusion_embedding(diffusion_step) 79 | 80 | skip = [] 81 | for layer in self.residual_layers: 82 | x, skip_connection = layer(x, mts_emb, diffusion_emb) 83 | skip.append(skip_connection) 84 | 85 | x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) 86 | x = x.reshape(B, self.channels, K * L) 87 | x = self.output_projection1(x) # (B,channel,K*L) 88 | x = F.relu(x) 89 | x = self.output_projection2(x) # (B,1,K*L) 90 | x = x.reshape(B, K, L) 91 | return x 92 | 93 | 94 | class ResidualBlock(nn.Module): 95 | """ 96 | A residual block that processes input data alongside diffusion embeddings and observed MTS part embedding, 97 | utilizing a gated mechanism. 98 | 99 | Parameters: 100 | - mts_emb_dim (int): Dimensionality of the embedding of the MTS 101 | - channels (int): Number of channels for the convolutional layers within the block. 102 | - diffusion_embedding_dim (int): Dimensionality of the diffusion embeddings. 103 | 104 | """ 105 | def __init__(self, mts_emb_dim, channels, diffusion_embedding_dim): 106 | super().__init__() 107 | self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels) 108 | self.cond_projection = Conv1d_with_init(mts_emb_dim, 2 * channels, 1) 109 | self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1) 110 | self.output_projection = Conv1d_with_init(channels, 2 * channels, 1) 111 | 112 | 113 | def forward(self, x, mts_emb, diffusion_emb): 114 | """ 115 | Forward pass of the ResidualBlock. 116 | 117 | Parameters: 118 | - x (Tensor): The projected corrupted input MTS. 119 | - mts_emb (Tensor): The embedding of the observed part of the MTS. 120 | - diffusion_emb (Tensor): The projected diffusion embedding tensor. 121 | 122 | Returns: 123 | - Tuple[Tensor, Tensor]: A tuple containing the updated data tensor and a skip connection tensor. 
124 | """ 125 | 126 | B, channel, K, L = x.shape 127 | base_shape = x.shape 128 | x = x.reshape(B, channel, K * L) 129 | 130 | diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(-1) # (B,channel,1) 131 | y = x + diffusion_emb 132 | 133 | y = self.mid_projection(y) # (B,2*channel,K*L) 134 | 135 | _, mts_emb_dim, _, _ = mts_emb.shape 136 | mts_emb = mts_emb.reshape(B, mts_emb_dim, K * L) #B, C, K*L 137 | mts_emb = self.cond_projection(mts_emb) # (B,2*channel,K*L) 138 | y = y + mts_emb 139 | 140 | gate, filter = torch.chunk(y, 2, dim=1) 141 | y = torch.sigmoid(gate) * torch.tanh(filter) # (B,channel,K*L) 142 | y = self.output_projection(y) 143 | 144 | residual, skip = torch.chunk(y, 2, dim=1) 145 | x = x.reshape(base_shape) 146 | residual = residual.reshape(base_shape) 147 | skip = skip.reshape(base_shape) 148 | return (x + residual) / math.sqrt(2.0), skip -------------------------------------------------------------------------------- /src/experiments/train_test_forecasting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import yaml 5 | import os 6 | import sys 7 | 8 | 9 | root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 10 | sys.path.append(root_dir) 11 | 12 | from data_loader.forecasting_dataloader import get_dataloader_forecasting 13 | from tsde.main_model import TSDE_Forecasting 14 | from utils.utils import train, evaluate, gsutil_cp, set_seed 15 | 16 | torch.backends.cudnn.enabled = False 17 | 18 | parser = argparse.ArgumentParser(description="TSDE-Forecasting") 19 | parser.add_argument("--config", type=str, default="base_forecasting.yaml") 20 | parser.add_argument('--device', default='cuda:0', help='Device for Attack') 21 | parser.add_argument("--seed", type=int, default=1) 22 | parser.add_argument('--linear', action='store_true', help='Linear mode flag') 23 | parser.add_argument('--sample_feat', action='store_true', help='Sample feature flag') 24 | 25 | 26 | parser.add_argument("--dataset", type=str, default='Electricity') 27 | parser.add_argument("--modelfolder", type=str, default="") 28 | parser.add_argument("--nsample", type=int, default=100) 29 | parser.add_argument("--run", type=int, default=1) 30 | parser.add_argument("--mix_masking_strategy", type=str, default='equal_p', help="Mix masking strategy (equal_p or probabilistic_layering)") 31 | 32 | args = parser.parse_args() 33 | print(args) 34 | 35 | path = "src/config/" + args.config 36 | with open(path, "r") as f: 37 | config = yaml.safe_load(f) 38 | 39 | config["model"]["mix_masking_strategy"] = args.mix_masking_strategy 40 | 41 | 42 | print(json.dumps(config, indent=4)) 43 | 44 | set_seed(args.seed) 45 | 46 | if args.dataset == "Electricity": 47 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + '_sample_feat_' + str(args.sample_feat)+"/" 48 | model = TSDE_Forecasting(config, args.device, target_dim=370, sample_feat=args.sample_feat).to(args.device) 49 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 50 | dataset_name='electricity', 51 | train_length=5833, 52 | skip_length=370*6, 53 | batch_size=config["train"]["batch_size"], 54 | device= args.device, 55 | ) 56 | 57 | elif args.dataset == "Solar": 58 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + 
'_sample_feat_' + str(args.sample_feat)+"/" 59 | model = TSDE_Forecasting(config, args.device, target_dim=137, sample_feat=args.sample_feat).to(args.device) 60 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 61 | dataset_name='solar', 62 | train_length=7009, 63 | skip_length=137*6, 64 | batch_size=config["train"]["batch_size"], 65 | device= args.device, 66 | ) 67 | 68 | elif args.dataset == "Traffic": 69 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + '_sample_feat_' + str(args.sample_feat)+"/" 70 | model = TSDE_Forecasting(config, args.device, target_dim=963, sample_feat=args.sample_feat).to(args.device) 71 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 72 | dataset_name='traffic', 73 | train_length=4001, 74 | skip_length=963*6, 75 | batch_size=config["train"]["batch_size"], 76 | device= args.device, 77 | ) 78 | 79 | elif args.dataset == "Taxi": 80 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + '_sample_feat_' + str(args.sample_feat)+"/" 81 | model = TSDE_Forecasting(config, args.device, target_dim=1214, sample_feat=args.sample_feat).to(args.device) 82 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 83 | dataset_name='taxi', 84 | train_length=1488, 85 | skip_length=1214*55, 86 | test_length=24*56, 87 | history_length=48, 88 | batch_size=config["train"]["batch_size"], 89 | device= args.device, 90 | ) 91 | 92 | elif args.dataset == "Wiki": 93 | foldername = "./save/Forecasting/" + args.dataset + "/n_samples_" + str(args.nsample) + "_run_" + str(args.run) + '_linear_' + str(args.linear) + '_sample_feat_' + str(args.sample_feat)+"/" 94 | model = TSDE_Forecasting(config, args.device, target_dim=2000, sample_feat=args.sample_feat).to(args.device) 95 | train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting( 96 | dataset_name='wiki', 97 | train_length=792, 98 | skip_length=9535*4, 99 | test_length=30*5, 100 | valid_length=30*5, 101 | history_length=90, 102 | pred_length=30, 103 | batch_size=config["train"]["batch_size"], 104 | device= args.device, 105 | ) 106 | 107 | else: 108 | print() 109 | 110 | 111 | print('model folder:', foldername) 112 | os.makedirs(foldername, exist_ok=True) 113 | 114 | 115 | with open(foldername + "config.json", "w") as f: 116 | json.dump(config, f, indent=4) 117 | 118 | 119 | if args.modelfolder == "": 120 | loss_path = foldername + "/losses.txt" 121 | with open(loss_path, "a") as file: 122 | file.write("Pretraining"+"\n") 123 | train( 124 | model, 125 | config["train"], 126 | train_loader, 127 | valid_loader=valid_loader, 128 | test_loader=test_loader, 129 | foldername=foldername, 130 | nsample=args.nsample, 131 | scaler=scaler, 132 | mean_scaler=mean_scaler, 133 | eval_epoch_interval=200, 134 | ) 135 | if config["finetuning"]["epochs"]!=0: 136 | print("Finetuning") 137 | ## Fine Tuning 138 | with open(loss_path, "a") as file: 139 | file.write("Finetuning"+"\n") 140 | checkpoint_path = foldername + "model.pth" 141 | model.load_state_dict(torch.load(checkpoint_path)) 142 | config["train"]["epochs"]=config["finetuning"]["epochs"] 143 | train( 144 | model, 145 | config["train"], 146 | train_loader, 147 | valid_loader=valid_loader, 148 | foldername=foldername, 149 | mode = 'Forecasting', 150 | ) 151 | 152 | 
else: 153 | model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth", map_location=args.device)) 154 | 155 | evaluate(model, test_loader, nsample=args.nsample, foldername=foldername, scaler=scaler, mean_scaler=mean_scaler,save_samples = True) 156 | 157 | -------------------------------------------------------------------------------- /src/base/mtsEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def get_torch_trans(heads=8, layers=1, channels=64): 6 | """ 7 | Creates a Transformer encoder module to process MTS timestamps/features as sequences. 8 | 9 | Parameters: 10 | - heads (int): Number of attention heads. 11 | - layers (int): Number of encoder layers. 12 | - channels (int): Dimensionality of the model (d_model in Transformer terminology). 13 | 14 | Returns: 15 | - nn.TransformerEncoder: A Transformer encoder object. 16 | """ 17 | encoder_layer = nn.TransformerEncoderLayer( 18 | d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu" 19 | ) 20 | return nn.TransformerEncoder(encoder_layer, num_layers=layers, enable_nested_tensor=False) 21 | 22 | 23 | def Conv1d_with_init(in_channels, out_channels, kernel_size): 24 | """ 25 | Initializes a 1D convolutional layer with Kaiming normal initialization. 26 | 27 | Parameters: 28 | - in_channels (int): Number of channels in the input signal. 29 | - out_channels (int): Number of channels produced by the convolution. 30 | - kernel_size (int): Size of the convolving kernel. 31 | 32 | Returns: 33 | - nn.Conv1d: A 1D convolutional layer with weights initialized. 34 | """ 35 | layer = nn.Conv1d(in_channels, out_channels, kernel_size) 36 | nn.init.kaiming_normal_(layer.weight) 37 | return layer 38 | 39 | class embedding_MTS(nn.Module): 40 | """ 41 | An embedding module for multivariate time series data, incorporating both temporal and spatial 42 | encoding via Transformer encoders and convolutional layers, it corresponds to the Embedding block in TSDE architecture. 43 | 44 | Parameters: 45 | - config (dict): A configuration dictionary containing model parameters such as number of channels, 46 | embedding dimensions, and number of heads for the Transformer encoders in the embedding block. 47 | """ 48 | 49 | 50 | def __init__(self, config): 51 | super().__init__() 52 | self.channels = config["channels"] 53 | self.timeemb = config["timeemb"] 54 | self.featureemb = config["featureemb"] 55 | self.emb_size = self.channels + self.timeemb + self.featureemb 56 | self.time_layer = get_torch_trans(heads=config["nheads"], layers=1, channels=self.emb_size) ## Temporal encoder 57 | self.feature_layer = get_torch_trans(heads=config["nheads"], layers=1, channels=self.emb_size) ## Spatial encoder 58 | 59 | self.input_projection = Conv1d_with_init(1, self.channels, 1) 60 | 61 | self.xt_projection = Conv1d_with_init(self.emb_size, self.channels, 1) 62 | self.xf_projection = Conv1d_with_init(self.emb_size, self.channels, 1) 63 | 64 | 65 | def forward_time(self, x, base_shape): 66 | """ 67 | Processes the input data through the temporal Transformer encoder (timestamps are considered as tokens). 68 | 69 | Parameters: 70 | - x (Tensor): The observed part of the MTS combined with timestamps embedding and features embedding as tensor. 71 | - base_shape (tuple): The base shape of the input tensor before reshaping. 72 | 73 | Returns: 74 | - Tensor: The temporally encoded tensor. 
75 | """ 76 | B, C, K, L = base_shape 77 | if L == 1: 78 | return x 79 | x = x.permute(0, 2, 1, 3).reshape(B * K, C, L) 80 | x = self.time_layer(x.permute(2, 0, 1)).permute(1, 2, 0) 81 | x = x.reshape(B, K, C, L).permute(0, 2, 1, 3) 82 | return x 83 | 84 | def forward_feature(self, x, base_shape): 85 | """ 86 | Processes the input data through the spatial Transformer encoder (features are considered as tokens). 87 | 88 | Parameters: 89 | - x (Tensor): The observed part of the MTS combined with timestamps embedding and features embedding as tensor. 90 | - base_shape (tuple): The base shape of the input tensor before reshaping. 91 | 92 | Returns: 93 | - Tensor: The spatially encoded tensor. 94 | """ 95 | B, C, K, L = base_shape 96 | if K == 1: 97 | return x 98 | x = x.permute(0, 3, 1, 2).reshape(B * L, C, K) 99 | x = self.feature_layer(x.permute(2, 0, 1)).permute(1, 2, 0) 100 | x = x.reshape(B, L, C, K).permute(0, 2, 3, 1) 101 | return x 102 | 103 | def forward(self, x, time_embed, feature_embed): 104 | """ 105 | The forward pass of the embedding module, processing multivariate time series data with 106 | temporal and spatial embeddings. 107 | 108 | Parameters: 109 | - x (Tensor): The observed part of the MTS. 110 | - time_embed (Tensor): The time embeddings tensor. 111 | - feature_embed (Tensor): The feature embeddings tensor. 112 | 113 | Returns: 114 | - Tuple[Tensor, Tensor, Tensor]: A tuple containing the combined temporal and spatial embeddings generated by the two transformer encoders, 115 | the processed temporal embedding, and the processed spatial embedding. 116 | """ 117 | B, _, K, L = x.shape 118 | time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1) 119 | feature_embed = feature_embed.unsqueeze(1).expand(-1, L, -1, -1) 120 | 121 | x = x.reshape(B, 1, K * L) 122 | x = self.input_projection(x) ## First Convolution before feeding the data to the encoders 123 | x = F.relu(x) 124 | x = x.reshape(B, self.channels, K, L) 125 | x = torch.cat([x, time_embed.permute(0, 3, 2, 1), feature_embed.permute(0, 3, 2, 1)], dim=1) 126 | 127 | base_shape = B, x.shape[1], K, L 128 | 129 | xt = self.forward_time(x, base_shape) 130 | xf = self.forward_feature(x, base_shape) 131 | 132 | xt = self.forward_time(xf, base_shape) 133 | xf = self.forward_feature(xt, base_shape) 134 | 135 | 136 | xt_reshaped = xt.reshape(B, base_shape[1], K*L) 137 | xt_proj = F.silu(self.xt_projection(xt_reshaped)) 138 | xt = xt_proj.reshape(B, self.channels, K, L) 139 | 140 | 141 | xf_reshaped = xf.reshape(B, base_shape[1], K*L) 142 | xf_proj = F.silu(self.xf_projection(xf_reshaped)) 143 | xf = xf_proj.reshape(B, self.channels, K, L) 144 | 145 | 146 | x = torch.cat([xt, xf], dim=1) # B, 2*C, K, L 147 | 148 | return x, xt_proj, xf_proj -------------------------------------------------------------------------------- /src/data_loader/elec_tslib_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torch.utils.data import Dataset, DataLoader 4 | 5 | import numpy as np 6 | import torch 7 | import warnings 8 | warnings.filterwarnings('ignore') 9 | from sklearn.preprocessing import StandardScaler 10 | 11 | 12 | '''class StandardScaler(object): 13 | 14 | def __init__(self): 15 | self.mean = 0. 16 | self.std = 1.
17 | 18 | def fit(self, data): 19 | self.mean = data.mean(0) 20 | self.std = data.std(0) 21 | 22 | def transform(self, data): 23 | mean = self.mean 24 | std = self.std 25 | return (data - mean) / std 26 | 27 | def inverse_transform(self, data): 28 | mean = self.mean 29 | std = self.std 30 | return (data * std) + mean''' 31 | 32 | 33 | class Dataset_Custom(Dataset): 34 | 35 | def __init__(self, root_path='dataset/electricity', flag='train', size=None, features='M', data_path='electricity.csv', target='OT', scale=True, 36 | inverse=False, timeenc=1, freq='t', cols=None, percentage=1): 37 | # size [seq_len, label_len, pred_len] 38 | # info 39 | super().__init__() 40 | self.seq_len = size[0] 41 | self.pred_len = size[1] 42 | # init 43 | assert flag in ['train', 'test', 'val'] 44 | type_map = {'train':0, 'val':1, 'test':2} 45 | self.set_type = type_map[flag] 46 | 47 | self.features = features 48 | self.target = target 49 | self.scale = scale 50 | self.inverse = inverse 51 | self.timeenc = timeenc 52 | self.freq = freq 53 | self.percentage = percentage 54 | self.cols=cols 55 | self.root_path = root_path 56 | self.data_path = data_path 57 | self.__read_data__() 58 | 59 | def __read_data__(self): 60 | self.scaler = StandardScaler() 61 | df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path)) 62 | length = len(df_raw)*self.percentage 63 | num_train = int(length*0.7) 64 | num_test = int(length*0.2) 65 | num_vali = int(length*0.1) 66 | # num_vali = len(df_raw) - num_train - num_test 67 | border1s = [0, num_train-self.seq_len, num_train+num_vali-self.seq_len] 68 | border2s = [num_train, num_train+num_vali, num_train+num_vali+num_test] 69 | border1 = border1s[self.set_type] 70 | border2 = border2s[self.set_type] 71 | 72 | if self.features == 'M': 73 | cols_data = df_raw.columns[1:] 74 | df_data = df_raw[cols_data] 75 | elif self.features == 'MS': 76 | cols_data = df_raw.columns[1:] 77 | df_data = df_raw[cols_data] 78 | elif self.features == 'S': 79 | df_data = df_raw[[self.target]] 80 | 81 | if self.scale: 82 | train_data = df_data[border1s[0]:border2s[0]] 83 | self.scaler.fit(train_data.values) 84 | 85 | '''for i in range(len(self.scaler.std)): 86 | if self.scaler.std[i] == 0: 87 | print(i)''' 88 | # print(len(self.scaler.std)) 89 | data = self.scaler.transform(df_data.values) 90 | else: 91 | data = df_data.values 92 | 93 | # df_stamp = df_raw[['date']][border1:border2] 94 | # df_stamp['date'] = pd.to_datetime(df_stamp.date) 95 | df_stamp = pd.date_range(start='4/1/2018',periods=border2-border1, freq='H') 96 | 97 | 98 | self.data_x = data[border1:border2] 99 | if self.inverse: 100 | self.data_y = df_data.values[border1:border2] 101 | else: 102 | self.data_y = data[border1:border2] 103 | 104 | 105 | def __getitem__(self, index): 106 | s_begin = index 107 | s_end = s_begin + self.seq_len 108 | r_begin = s_end 109 | r_end = r_begin + self.pred_len 110 | 111 | seq_x = self.data_x[s_begin:r_end] 112 | 113 | observed_mask = np.ones_like(seq_x) 114 | target_mask=observed_mask.copy() 115 | target_mask[-self.pred_len:] = 0 116 | s = { 117 | 'observed_data': seq_x, 118 | 'observed_mask': observed_mask, 119 | 'gt_mask': target_mask, 120 | 'timepoints': np.arange(self.seq_len+self.pred_len) * 1.0, 121 | 'feature_id': np.arange(seq_x.shape[1]) * 1.0, 122 | } 123 | #return seq_x, seq_y, seq_x_mark, seq_y_mark 124 | return s 125 | 126 | 127 | def __len__(self): 128 | return len(self.data_x) - self.seq_len - self.pred_len + 1 129 | 130 | def inverse_transform(self, data): 131 | return 
self.scaler.inverse_transform(data) 132 | 133 | 134 | 135 | def get_dataloader_elec(pred_length=96, history_length=192, batch_size=8, device='cuda:0'): 136 | 137 | 138 | train_dataset = Dataset_Custom(flag='train', size=[history_length, pred_length]) 139 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 140 | 141 | valid_dataset = Dataset_Custom(flag='val', size=[history_length, pred_length]) 142 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 143 | 144 | test_dataset = Dataset_Custom(flag='test', size=[history_length, pred_length]) 145 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 146 | print('Lengths', len(train_dataset), len(valid_dataset), len(test_dataset)) 147 | 148 | return train_loader, valid_loader, test_loader, test_dataset 149 | 150 | 151 | def get_dataloader_traffic(pred_length=96, history_length=192, batch_size=8, device='cuda:0'): 152 | 153 | train_dataset = Dataset_Custom(root_path='dataset/traffic', data_path='traffic.csv', flag='train', size=[history_length, pred_length]) 154 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 155 | 156 | valid_dataset = Dataset_Custom(root_path='dataset/traffic', data_path='traffic.csv', flag='val', size=[history_length, pred_length]) 157 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 158 | 159 | test_dataset = Dataset_Custom(root_path='dataset/traffic', data_path='traffic.csv', flag='test', size=[history_length, pred_length]) 160 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 161 | print('Lengths', len(train_dataset), len(valid_dataset), len(test_dataset)) 162 | 163 | return train_loader, valid_loader, test_loader, test_dataset 164 | -------------------------------------------------------------------------------- /src/data_loader/pm25_dataloader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from torch.utils.data import DataLoader, Dataset 3 | import pandas as pd 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class PM25_Dataset(Dataset): 9 | """ 10 | A Dataset class for loading and processing PM2.5 air quality data for use with PyTorch models. 11 | 12 | Parameters: 13 | - eval_length (int): The length of the time series evaluation window (total number of timestamps (L)). 14 | - target_dim (int): The total number of the different features in the multivariate time series (K). 15 | - mode (str): Specifies the mode of the dataset. Can be 'train', 'valid', or 'test'. 16 | - validindex (int): Index used to select the validation month in 'train' and 'valid' modes. 
17 | """ 18 | def __init__(self, eval_length=36, target_dim=36, mode="train", validindex=0): 19 | self.eval_length = eval_length 20 | self.target_dim = target_dim 21 | 22 | path = "./data/pm25/pm25_meanstd.pk" 23 | with open(path, "rb") as f: 24 | self.train_mean, self.train_std = pickle.load(f) 25 | if mode == "train": 26 | month_list = [1, 2, 4, 5, 7, 8, 10, 11] 27 | # 1st,4th,7th,10th months are excluded from histmask (since the months are used for creating missing patterns in test dataset) 28 | flag_for_histmask = [0, 1, 0, 1, 0, 1, 0, 1] 29 | month_list.pop(validindex) 30 | flag_for_histmask.pop(validindex) 31 | elif mode == "valid": 32 | month_list = [1, 2, 4, 5, 7, 8, 10, 11] 33 | month_list = month_list[validindex : validindex + 1] 34 | elif mode == "test": 35 | month_list = [3, 6, 9, 12] 36 | self.month_list = month_list 37 | 38 | # create data for batch 39 | self.observed_data = [] # values (separated into each month) 40 | self.observed_mask = [] # masks (separated into each month) 41 | self.gt_mask = [] # ground-truth masks (separated into each month) 42 | self.index_month = [] # indicate month 43 | self.position_in_month = [] # indicate the start position in month (length is the same as index_month) 44 | self.valid_for_histmask = [] # whether the sample is used for histmask 45 | self.use_index = [] # to separate train/valid/test 46 | self.cut_length = [] # excluded from evaluation targets 47 | 48 | df = pd.read_csv( 49 | "./data/pm25/Code/STMVL/SampleData/pm25_ground.txt", 50 | index_col="datetime", 51 | parse_dates=True, 52 | ) 53 | df_gt = pd.read_csv( 54 | "./data/pm25/Code/STMVL/SampleData/pm25_missing.txt", 55 | index_col="datetime", 56 | parse_dates=True, 57 | ) 58 | for i in range(len(month_list)): 59 | current_df = df[df.index.month == month_list[i]] 60 | current_df_gt = df_gt[df_gt.index.month == month_list[i]] 61 | current_length = len(current_df) - eval_length + 1 62 | last_index = len(self.index_month) 63 | self.index_month += np.array([i] * current_length).tolist() 64 | self.position_in_month += np.arange(current_length).tolist() 65 | if mode == "train": 66 | self.valid_for_histmask += np.array( 67 | [flag_for_histmask[i]] * current_length 68 | ).tolist() 69 | 70 | # mask values for observed indices are 1 71 | c_mask = 1 - current_df.isnull().values 72 | c_gt_mask = 1 - current_df_gt.isnull().values 73 | c_data = ( 74 | (current_df.fillna(0).values - self.train_mean) / self.train_std 75 | ) * c_mask 76 | 77 | self.observed_mask.append(c_mask) 78 | self.gt_mask.append(c_gt_mask) 79 | self.observed_data.append(c_data) 80 | 81 | if mode == "test": 82 | n_sample = len(current_df) // eval_length 83 | # interval size is eval_length (missing values are imputed only once) 84 | c_index = np.arange( 85 | last_index, last_index + eval_length * n_sample, eval_length 86 | ) 87 | self.use_index += c_index.tolist() 88 | self.cut_length += [0] * len(c_index) 89 | if len(current_df) % eval_length != 0: # avoid double-count for the last time-series 90 | self.use_index += [len(self.index_month) - 1] 91 | self.cut_length += [eval_length - len(current_df) % eval_length] 92 | 93 | if mode != "test": 94 | self.use_index = np.arange(len(self.index_month)) 95 | self.cut_length = [0] * len(self.use_index) 96 | 97 | # masks for 1st,4th,7th,10th months are used for creating missing patterns in test data, 98 | # so these months are excluded from histmask to avoid leakage 99 | if mode == "train": 100 | ind = -1 101 | self.index_month_histmask = [] 102 | self.position_in_month_histmask = [] 103 
| 104 | for i in range(len(self.index_month)): 105 | while True: 106 | ind += 1 107 | if ind == len(self.index_month): 108 | ind = 0 109 | if self.valid_for_histmask[ind] == 1: 110 | self.index_month_histmask.append(self.index_month[ind]) 111 | self.position_in_month_histmask.append( 112 | self.position_in_month[ind] 113 | ) 114 | break 115 | else: # dummy (histmask is only used for training) 116 | self.index_month_histmask = self.index_month 117 | self.position_in_month_histmask = self.position_in_month 118 | 119 | def __getitem__(self, org_index): 120 | """ 121 | Returns a sample from the dataset at the specified index. 122 | 123 | Parameters: 124 | - org_index (int): The original index of the sample to retrieve. 125 | 126 | Returns: 127 | - dict: A dictionary containing the data sample including observed data, masks, and timepoints, and a mask to avoid double evaluation of overlapping values. 128 | """ 129 | index = self.use_index[org_index] 130 | c_month = self.index_month[index] 131 | c_index = self.position_in_month[index] 132 | hist_month = self.index_month_histmask[index] 133 | hist_index = self.position_in_month_histmask[index] 134 | s = { 135 | "observed_data": self.observed_data[c_month][ 136 | c_index : c_index + self.eval_length 137 | ], 138 | "observed_mask": self.observed_mask[c_month][ 139 | c_index : c_index + self.eval_length 140 | ], 141 | "gt_mask": self.gt_mask[c_month][ 142 | c_index : c_index + self.eval_length 143 | ], 144 | "hist_mask": self.observed_mask[hist_month][ 145 | hist_index : hist_index + self.eval_length 146 | ], 147 | "timepoints": np.arange(self.eval_length), 148 | "cut_length": self.cut_length[org_index], 149 | } 150 | 151 | return s 152 | 153 | def __len__(self): 154 | """ 155 | Returns the total number of samples in the dataset. 156 | 157 | Returns: 158 | - int: Total number of samples. 159 | """ 160 | return len(self.use_index) 161 | 162 | 163 | def get_dataloader_pm25(batch_size, device, validindex=0): 164 | """ 165 | Prepares DataLoader objects for the PM2.5 dataset for training, validation, and testing. Also, returns 166 | normalization parameters (scaler and mean_scaler) to normalize the data. 167 | 168 | Parameters: 169 | - batch_size (int): Batch size for the DataLoader. 170 | - device (torch.device): The device on which the tensors will be loaded. 171 | - validindex (int): Index used to select the validation month in 'train' and 'valid' modes. 172 | 173 | Returns: 174 | - Tuple: Contains DataLoader objects for training, validation, and testing, along with normalization parameters (scaler, mean_scaler). 
175 | """ 176 | dataset = PM25_Dataset(mode="train", validindex=validindex) 177 | train_loader = DataLoader( 178 | dataset, batch_size=batch_size, num_workers=1, shuffle=True 179 | ) 180 | dataset_test = PM25_Dataset(mode="test", validindex=validindex) 181 | test_loader = DataLoader( 182 | dataset_test, batch_size=batch_size, num_workers=1, shuffle=False 183 | ) 184 | dataset_valid = PM25_Dataset(mode="valid", validindex=validindex) 185 | valid_loader = DataLoader( 186 | dataset_valid, batch_size=batch_size, num_workers=1, shuffle=False 187 | ) 188 | 189 | scaler = torch.from_numpy(dataset.train_std).to(device).float() 190 | mean_scaler = torch.from_numpy(dataset.train_mean).to(device).float() 191 | 192 | return train_loader, valid_loader, test_loader, scaler, mean_scaler 193 | 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | ⚠️ This repository has migrated ⚠️
3 | This repository will not be maintained any further, and issues and pull requests may be ignored. For an up to date codebase, issues, and pull requests, please continue to the new repository.
4 | 5 | 6 | 7 |
8 | ------
9 | 10 | 11 | 12 |
13 | **Self-Supervised Learning of Time Series Representation via Diffusion Process and Imputation-Interpolation-Forecasting Mask**
14 | 15 | 16 |
17 | GitHub License
18 | 19 |
20 | Paper URL
21 | 22 | 23 | 24 |
25 | Usage • 26 | Examples • 27 | Checkpoints • 28 | Processed Datasets • 29 | Citation
30 | 31 |
32 | 33 | ------ 34 | 35 | :triangular_flag_on_post:**News** (2024.06) We have added scripts to run forecasting experiments on the Electricity dataset provided by [TimesNet](https://github.com/thuml/Time-Series-Library). 36 | 37 | :triangular_flag_on_post:**News** (2024.05) This work has been accepted to the main research track at KDD 2024. 38 | 39 | ------ 40 | 41 | 42 | > Time Series Diffusion Embedding (TSDE) bridges the gap in leveraging diffusion models for Time Series Representation Learning (TSRL) as the first diffusion-based SSL TSRL approach. TSDE segments time series data into observed and masked parts using an Imputation-Interpolation-Forecasting (IIF) mask. It applies a trainable embedding function, featuring dual-orthogonal Transformer encoders with a crossover mechanism, to the observed part. We train a reverse diffusion process conditioned on the embeddings, designed to predict noise added to the masked part. Extensive experiments demonstrate TSDE’s superiority in imputation, interpolation, forecasting, anomaly detection, classification, and clustering. 43 | 44 | 45 | TSDE Architecture 46 | 47 | 48 | ## Usage 49 | 50 | We recommend starting by installing the dependencies in a virtual environment. 51 | ``` 52 | conda create --name tsde python=3.11 -y 53 | conda activate tsde 54 | pip install -r requirements.txt 55 | ``` 56 | 57 | ### Datasets 58 | Download the public datasets used in our experiments: 59 | 60 | ``` 61 | python src/utils/download_data.py [dataset-name] 62 | ``` 63 | Options for [dataset-name]: physio, pm25, electricity, solar, traffic, taxi, wiki, msl, smd, smap, swat and psm 64 | 65 | ### Imputation 66 | 67 | To run the imputation experiments on the PhysioNet dataset: 68 | ``` 69 | python src/experiments/train_test_imputation.py --device [device] --dataset PhysioNet --physionet_classification True --testmissingratio [test_missing_ratio] 70 | ``` 71 | In our experiments, we set [test_missing_ratio] to: 0.1, 0.5, and 0.9 72 | 73 | To run the imputation experiments on the PM2.5 dataset, first set train-epochs to 1500 and finetuning-epochs to 100 in src/config/base.yaml, and run the following command: 74 | ``` 75 | python src/experiments/train_test_imputation.py --device [device] --dataset Pm25 76 | ``` 77 | 78 | ### Interpolation 79 | To run the interpolation experiments on the PhysioNet dataset: 80 | ``` 81 | python src/experiments/train_test_interpolation.py --device [device] --dataset PhysioNet --physionet_classification True --testmissingratio [test_missing_ratio] 82 | ``` 83 | In our experiments, we set [test_missing_ratio] to: 0.1, 0.5, and 0.9 84 | 85 | ### Forecasting 86 | Please first set the number of pretraining and finetuning epochs for each dataset in src/config/base_forecasting.yaml, and set the number of features used for sub-sampling during training in the TSDE_Forecasting model in src/tsde/main_model.py. 87 | Run the following command: 88 | ``` 89 | python src/experiments/train_test_forecasting.py --dataset [dataset-name] --device [device] --sample_feat 90 | ``` 91 | Please remove the flag --sample_feat to disable the sub-sampling of features during training. 92 | ### Anomaly Detection 93 | Please first set the number of features and the number of pretraining and finetuning epochs for each dataset in src/config/base_ad.yaml.
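The experiment scripts read these YAML files into a nested dict and index keys such as `config["train"]["epochs"]`, `config["train"]["batch_size"]`, and `config["finetuning"]["epochs"]` (see `src/experiments/train_test_forecasting.py`). As a minimal sketch, the epoch settings can also be adjusted programmatically before launching a run; only keys whose names appear in the training scripts are shown, the example values are illustrative, and `base_ad.yaml` may contain additional fields such as the number of features:
```
import yaml

# Load the anomaly-detection config, adjust the epoch budgets, and write it back.
# The key names below mirror how the training scripts index the config dict;
# the example values are illustrative only.
with open("src/config/base_ad.yaml") as f:
    config = yaml.safe_load(f)

config["train"]["epochs"] = 1500      # pretraining epochs (example value)
config["finetuning"]["epochs"] = 100  # finetuning epochs (example value)

with open("src/config/base_ad.yaml", "w") as f:
    yaml.safe_dump(config, f)
```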
94 | Run the following command: 95 | ``` 96 | python src/experiments/train_test_anomaly_detection.py --dataset [dataset-name] --device [device] --seed [seed] --anomaly_ratio [anomaly_ratio] 97 | ``` 98 | 99 | The values of [dataset-name], [seed] and [anomaly_ratio] used in our experiments are available in our paper. 100 | 101 | ### Classification on PhysioNet 102 | Run the following command: 103 | ``` 104 | python src/experiments/train_test_classification.py --seed [seed] --device [device] --testmissingratio [test_missing_ratio] 105 | ``` 106 | ### Benchmarking TSDE against Time Series Library models 107 | 1. Download the Electricity dataset following the guidelines available [here](https://github.com/thuml/Time-Series-Library). The dataset folder should be placed in the root directory. 108 | 2. Run the following command: 109 | ``` 110 | python src/experiments/train_test_tslib_forecasting.py --device [device] --pred_length [pred_length] --hist_length [hist_length] 111 | ``` 112 | The values of [pred_length] and [hist_length] used in our experiments are available in our paper. 113 | 114 | ## Examples 115 | #### Examples of imputation, interpolation and forecasting 116 | Examples of imputation, interpolation and forecasting 117 | 118 | #### Example of clustering 119 | Example of clustering 120 | 121 | #### Example of embedding visualization 122 | Example of embedding visualization 123 | 124 | ## Checkpoints 125 | To run the evaluation using a specific checkpoint, follow the instructions below. Ensure your environment is set up correctly and that the datasets have been downloaded first. 126 | ### Imputation 127 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Imputation.zip). Place the content of this folder under [root_dir]/save. 128 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should exclude '[root_dir]/save' and 'model.pth'. 129 | ``` 130 | python src/experiments/train_test_imputation.py --device [device] --dataset PhysioNet --physionet_classification True --testmissingratio 0.1 --modelfolder [path_to_checkpoint_folder] --run [run_number] 131 | ``` 132 | ``` 133 | python src/experiments/train_test_imputation.py --device [device] --dataset Pm25 --modelfolder [path_to_checkpoint_folder] --run [run_number] 134 | ``` 135 | 136 | ### Interpolation 137 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Interpolation.zip). Place the content of this folder under [root_dir]/save. 138 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should exclude '[root_dir]/save' and 'model.pth'. 139 | ``` 140 | python src/experiments/train_test_interpolation.py --device [device] --dataset PhysioNet --physionet_classification True --testmissingratio 0.1 --modelfolder [path_to_checkpoint_folder] --run [run_number] 141 | ``` 142 | ### Forecasting 143 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Forecasting.zip). Place the content of this folder under [root_dir]/save. 144 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should exclude '[root_dir]/save' and 'model.pth'.
145 | ``` 146 | python src/experiments/train_test_forecasting.py --device [device] --dataset [dataset-name] --modelfolder [path_to_checkpoint_folder] --run [run_number] 147 | ``` 148 | 149 | ### Anomaly Detection 150 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Anomaly_Detection.zip). Place the content of this folder under [root_dir]/save. 151 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should exclude '[root_dir]/save' and 'model.pth'. 152 | ``` 153 | python src/experiments/train_test_anomaly_detection.py --device [device] --dataset [dataset-name] --modelfolder [path_to_checkpoint_folder] --run [run_number] --disable_finetune 154 | ``` 155 | 156 | ### Classification 157 | 1. **Download the checkpoints**: Access and download the required checkpoints from [here](https://storage.googleapis.com/motherbrain-tsde/Checkpoints/Classification.zip). Place the content of this folder under [root_dir]/save. 158 | 2. **Run the evaluation command** by setting `[path_to_checkpoint_folder]` accordingly. The path should include '[root_dir]/save' and exclude 'model.pth'. 159 | ``` 160 | python src/experiments/train_test_classification.py --device [device] --modelfolder [path_to_checkpoint_folder] --run [run_number] --disable_finetune 161 | ``` 162 | ## Citation 163 | ``` 164 | @article{senane2024tsde, 165 | title={{Self-Supervised Learning of Time Series Representation via Diffusion Process and Imputation-Interpolation-Forecasting Mask}}, 166 | author={Senane, Zineb and Cao, Lele and Buchner, Valentin Leonhard and Tashiro, Yusuke and You, Lei and Herman, Pawel and Nordahl, Mats and Tu, Ruibo and von Ehrenheim, Vilhelm}, 167 | year={2024}, 168 | eprint={2405.05959}, 169 | archivePrefix={arXiv}, 170 | primaryClass={cs.LG} 171 | } 172 | ``` 173 | -------------------------------------------------------------------------------- /src/utils/download_data.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | import zipfile 3 | import sys 4 | import os 5 | import wget 6 | import requests 7 | import pandas as pd 8 | import pickle 9 | import gdown 10 | 11 | # Define a root URL for downloading forecasting datasets used in Salinas et al. https://proceedings.neurips.cc/paper/2019/file/0b105cf1504c4e241fcc6d519ea962fb-Paper.pdf 12 | root = "https://raw.githubusercontent.com/mbohlkeschneider/gluon-ts/mv_release/datasets/" 13 | 14 | 15 | def get_confirm_token(response): 16 | """ 17 | Extract the confirmation token from the response cookies, which is required for some downloads. 18 | 19 | Parameters: 20 | - response: The response object from a requests session. 21 | 22 | Returns: 23 | - The confirmation token if found; otherwise, None. 24 | """ 25 | for key, value in response.cookies.items(): 26 | if key.startswith('download_warning'): 27 | return value 28 | 29 | return None 30 | 31 | def save_response_content(response, destination): 32 | """ 33 | Saves the content of a response object to a file in chunks, which is useful for large downloads. 34 | 35 | Parameters: 36 | - response: The response object from a requests session. 37 | - destination: The file path where the content will be saved.
38 | """ 39 | 40 | CHUNK_SIZE = 32768 41 | 42 | with open(destination, "wb") as f: 43 | for chunk in response.iter_content(CHUNK_SIZE): 44 | if chunk: # filter out keep-alive new chunks 45 | f.write(chunk) 46 | 47 | 48 | def download_file_from_google_drive(url, destination): 49 | """ 50 | Downloads a file from Google Drive by handling the confirmation token and saving the response content. 51 | 52 | Parameters: 53 | - url: The Google Drive file URL. 54 | - destination: The file path where the downloaded data will be saved. 55 | """ 56 | session = requests.Session() 57 | 58 | response = session.get(url, stream = True) 59 | token = get_confirm_token(response) 60 | 61 | if token: 62 | params = { 'confirm' : token } 63 | response = session.get(url, params = params, stream = True) 64 | save_response_content(response, destination) 65 | 66 | # Ensure the data directory exists 67 | os.makedirs("data/", exist_ok=True) 68 | 69 | # Download and extract datasets based on the command-line argument 70 | if sys.argv[1] == "physio": 71 | # Install PhysioNet dataset 72 | 73 | url = "https://physionet.org/files/challenge-2012/1.0.0/set-a.tar.gz?download" 74 | url_outcomes = "https://physionet.org/files/challenge-2012/1.0.0/Outcomes-a.txt?download" 75 | os.makedirs("data/physio", exist_ok=True) 76 | wget.download(url, out="data/") 77 | wget.download(url_outcomes, out="data/physio") 78 | with tarfile.open("data/set-a.tar.gz", "r:gz") as t: 79 | t.extractall(path="data/physio") 80 | print(f"Downloaded data/physio") 81 | os.remove("data/set-a.tar.gz") 82 | 83 | 84 | elif sys.argv[1] == "pm25": 85 | # Install PM2.5 dataset 86 | 87 | url = "https://www.microsoft.com/en-us/research/uploads/prod/2016/06/STMVL-Release.zip" 88 | 89 | headers = { 90 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' 91 | } 92 | filename = "data/STMVL-Release.zip" 93 | with requests.get(url, stream=True, headers=headers) as response: 94 | # Check if the request was successful (status code 200) 95 | if response.status_code == 200: 96 | with open(filename, 'wb') as f: 97 | # Write the content of the file to the local file 98 | f.write(response.content) 99 | print(f"Downloaded {filename}") 100 | else: 101 | print(f"Failed to download the file. Status code: {response.status_code}") 102 | with zipfile.ZipFile(filename) as z: 103 | z.extractall("data/pm25") 104 | 105 | os.remove(filename) 106 | 107 | def create_normalizer_pm25(): 108 | """ 109 | Normalizes the PM2.5 dataset by calculating and saving the mean and standard deviation, excluding test months. 110 | 111 | 112 | It saves these values for use in normalizing the training and testing data. 
113 | """ 114 | df = pd.read_csv( 115 | "./data/pm25/Code/STMVL/SampleData/pm25_ground.txt", 116 | index_col="datetime", # Use the datetime column as the index 117 | parse_dates=True, 118 | ) 119 | 120 | # Define the months to exclude from the normalization process (test months) 121 | test_month = [3, 6, 9, 12] 122 | # Exclude the test months from the dataset 123 | for i in test_month: 124 | df = df[df.index.month != i] 125 | # Calculate the mean and standard deviation for the dataset after excluding the test months 126 | mean = df.describe().loc["mean"].values 127 | std = df.describe().loc["std"].values 128 | 129 | # Save the mean and standard deviation to a file using pickle 130 | path = "./data/pm25/pm25_meanstd.pk" 131 | with open(path, "wb") as f: 132 | pickle.dump([mean, std], f) 133 | create_normalizer_pm25() 134 | 135 | elif sys.argv[1] == "electricity": 136 | # Install Electricity dataset 137 | 138 | url = root + "electricity_nips.tar.gz?download" 139 | wget.download(url, out="data") 140 | with tarfile.open("data/electricity_nips.tar.gz", "r:gz") as t: 141 | t.extractall(path="data/electricity") 142 | print(f"Downloaded data/electricity") 143 | os.remove("data/electricity_nips.tar.gz") 144 | 145 | elif sys.argv[1] == "solar": 146 | # Install Solar dataset 147 | 148 | url = root + "solar_nips.tar.gz?download" 149 | wget.download(url, out="data") 150 | with tarfile.open("data/solar_nips.tar.gz", "r:gz") as t: 151 | t.extractall(path="data/solar") 152 | print(f"Downloaded data/solar") 153 | os.remove("data/solar_nips.tar.gz") 154 | os.rename("data/solar/solar_nips/train/train.json", "data/solar/solar_nips/train/data.json") 155 | os.rename("data/solar/solar_nips/test/test.json", "data/solar/solar_nips/test/data.json") 156 | 157 | elif sys.argv[1] == "traffic": 158 | # Install Traffic dataset 159 | 160 | url = root + "traffic_nips.tar.gz?download" 161 | wget.download(url, out="data") 162 | with tarfile.open("data/traffic_nips.tar.gz", "r:gz") as t: 163 | t.extractall(path="data/traffic") 164 | print(f"Downloaded data/traffic") 165 | os.remove("data/traffic_nips.tar.gz") 166 | 167 | elif sys.argv[1] == "taxi": 168 | # Install Taxi dataset 169 | 170 | url = root + "taxi_30min.tar.gz?download" 171 | wget.download(url, out="data") 172 | with tarfile.open("data/taxi_30min.tar.gz", "r:gz") as t: 173 | t.extractall(path="data/taxi") 174 | print(f"Downloaded data/taxi") 175 | os.remove("data/taxi_30min.tar.gz") 176 | os.rename("data/taxi/taxi_30min/", "data/taxi/taxi_nips") 177 | os.rename("data/taxi/taxi_nips/train/train.json", "data/taxi/taxi_nips/train/data.json") 178 | os.rename("data/taxi/taxi_nips/test/test.json", "data/taxi/taxi_nips/test/data.json") 179 | 180 | 181 | elif sys.argv[1] == "wiki": 182 | # Install Wiki dataset 183 | 184 | url = "https://github.com/awslabs/gluonts/raw/1553651ca1fca63a16e012b8927bd9ce72b8e79e/datasets/wiki-rolling_nips.tar.gz" 185 | wget.download(url, out="data") 186 | with tarfile.open("data/wiki-rolling_nips.tar.gz", "r:gz") as t: 187 | t.extractall(path="data/wiki") 188 | print(f"Downloaded data/wiki") 189 | os.remove("data/wiki-rolling_nips.tar.gz") 190 | os.rename("data/wiki/wiki-rolling_nips/", "data/wiki/wiki_nips") 191 | os.rename("data/wiki/wiki_nips/train/train.json", "data/wiki/wiki_nips/train/data.json") 192 | os.rename("data/wiki/wiki_nips/test/test.json", "data/wiki/wiki_nips/test/data.json") 193 | 194 | elif sys.argv[1] == "msl": 195 | # Install MSL dataset 196 | 197 | url = 'https://drive.google.com/uc?id=14STjpszyi6D0B7BUHZ1L4GLUkhhPXE0G' 
198 | gdown.download(url, 'MSL.zip', quiet=False) 199 | with zipfile.ZipFile('MSL.zip') as z: 200 | z.extractall("data/anomaly_detection") 201 | os.remove('MSL.zip') 202 | print(f"Downloaded MSL dataset") 203 | 204 | elif sys.argv[1] == "psm": 205 | # Install PSM dataset 206 | 207 | url = "https://drive.google.com/uc?id=14gCVQRciS2hs2SAjXpqioxE4CUzaYkhb" 208 | gdown.download(url, 'PSM.zip', quiet=False) 209 | with zipfile.ZipFile('PSM.zip') as z: 210 | z.extractall("data/anomaly_detection") 211 | os.remove('PSM.zip') 212 | print(f"Downloaded PSM dataset") 213 | 214 | elif sys.argv[1] == "smap": 215 | # Install SMAP dataset 216 | 217 | url = "https://drive.google.com/uc?id=1kxiTMOouw1p-yJMkb_Q_CGMjakVNtg3X" 218 | gdown.download(url, 'SMAP.zip', quiet=False) 219 | with zipfile.ZipFile('SMAP.zip') as z: 220 | z.extractall("data/anomaly_detection") 221 | os.remove('SMAP.zip') 222 | print(f"Downloaded SMAP dataset") 223 | 224 | 225 | elif sys.argv[1] == "smd": 226 | # Install SMD dataset 227 | 228 | url = "https://drive.google.com/uc?id=1BgjQ7_2uqRrZ789Pijtpid5xpLTniywu" 229 | gdown.download(url, 'SMD.zip', quiet=False) 230 | with zipfile.ZipFile('SMD.zip') as z: 231 | z.extractall("data/anomaly_detection") 232 | os.remove('SMD.zip') 233 | print(f"Downloaded SMD dataset") 234 | 235 | 236 | elif sys.argv[1] == "swat": 237 | # Install SWaT dataset 238 | 239 | url = "https://drive.google.com/uc?id=1eRKQwJhqmUD4LkWnqNy1cdIz3W_y6EtW" 240 | gdown.download(url, 'SWaT.zip', quiet=False) 241 | with zipfile.ZipFile('SWaT.zip') as z: 242 | z.extractall("data/anomaly_detection") 243 | os.remove('SWaT.zip') 244 | print(f"Downloaded SWaT dataset") 245 | 246 | 247 | -------------------------------------------------------------------------------- /src/utils/masking_strategies.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | 6 | 7 | def imputation_mask_batch(observed_mask): 8 | """ 9 | Generates a batch of masks for imputation task where certain observations are randomly masked based on a 10 | sample-specific ratio. 11 | 12 | Parameters: 13 | - observed_mask (Tensor): A tensor indicating observed values (1 for observed, 0 for missing). 14 | 15 | Returns: 16 | - Tensor: A mask tensor for imputation with the same shape as `observed_mask`. 17 | """ 18 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values 19 | rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1) 20 | min_value, max_value = 0.1, 0.9 21 | for i in range(len(observed_mask)): 22 | sample_ratio = min_value + (max_value - min_value)*np.random.rand() # missing ratio ## at random 23 | num_observed = observed_mask[i].sum().item() 24 | num_masked = round(num_observed * sample_ratio) 25 | rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1 26 | cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float() 27 | return cond_mask 28 | 29 | def interpolation_mask_batch(observed_mask): 30 | """ 31 | Generates a batch of masks for interpolation task by randomly selecting timestamps to mask across all features. 32 | 33 | Parameters: 34 | - observed_mask (Tensor): A tensor indicating observed values. 35 | 36 | Returns: 37 | - Tensor: A mask tensor for interpolation tasks. 
38 | """ 39 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values 40 | total_timestamps = observed_mask.shape[2] 41 | timestamps = np.arange(total_timestamps) 42 | for i in range(len(observed_mask)): 43 | mask_timestamp = np.random.choice( 44 | timestamps 45 | ) 46 | rand_for_mask[i][:,mask_timestamp] = -1 47 | cond_mask = (rand_for_mask > 0).float() 48 | return cond_mask 49 | 50 | 51 | def forecasting_mask_batch(observed_mask): 52 | """ 53 | Generates a batch of masks for forecasting task by masking out all future values beyond a randomly selected start 54 | point in the sequence (30% timestamps at most). 55 | 56 | Parameters: 57 | - observed_mask (Tensor): A tensor indicating observed values. 58 | 59 | Returns: 60 | - Tensor: A mask tensor for forecasting tasks. 61 | """ 62 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values 63 | total_timestamps = observed_mask.shape[2] 64 | start_pred_timestamps = round(total_timestamps/3) 65 | timestamps = np.arange(total_timestamps)[start_pred_timestamps:] 66 | for i in range(len(observed_mask)): 67 | start_forecast_mask = np.random.choice( 68 | timestamps 69 | ) 70 | rand_for_mask[i][:,-start_forecast_mask:] = -1 71 | cond_mask = (rand_for_mask > 0).float() 72 | return cond_mask 73 | 74 | 75 | def forecasting_imputation_mask_batch(observed_mask): 76 | """ 77 | Generates a batch of masks for forecasting/imputation task by masking out all 78 | future values for a random subset of features beyond a randomly selected start 79 | point in the sequence (30% timestamps at most). 80 | 81 | Parameters: 82 | - observed_mask (Tensor): A tensor indicating observed values. 83 | 84 | Returns: 85 | - Tensor: A mask tensor for forecasting tasks. 86 | """ 87 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values 88 | total_timestamps = observed_mask.shape[2] 89 | start_pred_timestamps = round(total_timestamps/3) 90 | timestamps = np.arange(total_timestamps)[start_pred_timestamps:] 91 | 92 | for i in range(len(observed_mask)): 93 | batch_indices = list(np.arange(0, len(rand_for_mask[i]))) 94 | n_keep_dims = random.choice([1, 2, 3]) # pick how many dims to keep unmasked 95 | keep_dims_idx = random.sample(batch_indices, n_keep_dims) # choose the dims to keep 96 | mask_dims_idx = [i for i in batch_indices if i not in keep_dims_idx] # choose the dims to mask 97 | start_forecast_mask = np.random.choice( 98 | timestamps 99 | ) 100 | rand_for_mask[i][mask_dims_idx, -start_forecast_mask:] = -1 101 | cond_mask = (rand_for_mask > 0).float() 102 | return cond_mask 103 | 104 | 105 | def imputation_mask_sample(observed_mask): 106 | """ 107 | Generates a mask for imputation for a single sample, similar to `imputation_mask_batch` but for an individual sample. 108 | 109 | Parameters: 110 | - observed_mask (Tensor): A tensor indicating observed values for a single sample. 111 | 112 | Returns: 113 | - Tensor: A mask tensor for imputation for the sample. 
114 | """ 115 | ## Observed mask of shape KxL 116 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values 117 | 118 | rand_for_mask = rand_for_mask.reshape(-1) 119 | min_value, max_value = 0.1, 0.9 120 | sample_ratio = min_value + (max_value - min_value)*np.random.rand() 121 | num_observed = observed_mask.sum().item() 122 | num_masked = round(num_observed * sample_ratio) 123 | rand_for_mask[rand_for_mask.topk(num_masked).indices] = -1 124 | cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float() 125 | return cond_mask 126 | 127 | def interpolation_mask_sample(observed_mask): 128 | """ 129 | Generates a mask for interpolation for a single sample by randomly selecting a timestamp to mask. 130 | 131 | Parameters: 132 | - observed_mask (Tensor): A tensor indicating observed values for a single sample. 133 | 134 | Returns: 135 | - Tensor: A mask tensor for interpolation for the sample. 136 | """ 137 | ## Observed mask of shape KxL 138 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values 139 | total_timestamps = observed_mask.shape[1] 140 | timestamps = np.arange(total_timestamps) 141 | 142 | mask_timestamp = np.random.choice( 143 | timestamps 144 | ) 145 | rand_for_mask[:,mask_timestamp] = -1 146 | cond_mask = (rand_for_mask > 0).float() 147 | return cond_mask 148 | 149 | 150 | def forecasting_mask_sample(observed_mask): 151 | """ 152 | Generates a mask for forecasting for a single sample by masking out all future values beyond a selected timestamp. 153 | 154 | Parameters: 155 | - observed_mask (Tensor): A tensor indicating observed values for a single sample. 156 | 157 | Returns: 158 | - Tensor: A mask tensor for forecasting for the sample. 159 | """ 160 | ## Observed mask of shape KxL 161 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask ### array like observed mask filled with random values 162 | total_timestamps = observed_mask.shape[1] 163 | 164 | start_pred_timestamps = round(total_timestamps/3) 165 | timestamps = np.arange(total_timestamps)[-start_pred_timestamps:] 166 | 167 | start_forecast_mask = np.random.choice( 168 | timestamps 169 | ) 170 | rand_for_mask[:,start_forecast_mask:] = -1 171 | cond_mask = (rand_for_mask > 0).float() 172 | 173 | return cond_mask 174 | 175 | 176 | 177 | def get_mask_equal_p_sample(observed_mask): 178 | """ 179 | IIF mix masking strategy. 180 | Generates masks for a batch of samples where each sample has an equal probability of being assigned one of the 181 | three mask types: imputation, interpolation, or forecasting. 182 | 183 | Parameters: 184 | - observed_mask (Tensor): A tensor indicating observed values for a batch of samples. 185 | 186 | Returns: 187 | - Tensor: A batch of masks with a mix of the three types. 
188 | """ 189 | B, K, L = observed_mask.shape 190 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask 191 | for i in range(B): 192 | 193 | threshold = 1/3 194 | 195 | imp_mask = imputation_mask_sample(observed_mask[i]) 196 | p = np.random.rand() # missing probability at random 197 | 198 | if p < threshold: 199 | rand_for_mask[i] = imp_mask 200 | elif p < 2 * threshold: 201 | rand_for_mask[i] = interpolation_mask_sample(observed_mask[i]) 202 | else: 203 | rand_for_mask[i] = forecasting_mask_sample(observed_mask[i]) 241 | if mask_type == 'imputation': 242 | mask = imputation_mask_sample(mask) 243 | elif mask_type == 'interpolation': 244 | mask = interpolation_mask_sample(mask) 245 | else: 246 | mask = forecasting_mask_sample(mask) 247 | 248 | m = torch.sum(torch.eq(mask, 0))/(K*L) 249 | 250 | rand_for_mask[i]=mask 251 | 252 | return rand_for_mask 253 | 254 | def pattern_mask_batch(observed_mask): 255 | """ 256 | Generates a batch of masks based on a predetermined pattern or a random choice between imputation mask and a 257 | previously used mask pattern. Used for finetuning TSDE on PM25 dataset. 258 | 259 | Parameters: 260 | - observed_mask (Tensor): A tensor indicating observed values for a batch of samples. 261 | 262 | Returns: 263 | - Tensor: A batch of masks where each mask is either an imputation mask or follows a specific pattern. 264 | """ 265 | pattern_mask = observed_mask 266 | rand_mask = imputation_mask_batch(observed_mask) 267 | 268 | cond_mask = observed_mask.clone() ### Gradients can flow back to observed_mask 269 | for i in range(len(cond_mask)): 270 | mask_choice = np.random.rand() 271 | if mask_choice > 0.5: 272 | cond_mask[i] = rand_mask[i] 273 | else: # draw another sample for histmask (i-1 corresponds to another sample) ###### Not randomly sampled? 274 | cond_mask[i] = cond_mask[i] * pattern_mask[i - 1] 275 | return cond_mask 276 | 277 | 278 | -------------------------------------------------------------------------------- /src/data_loader/physio_dataloader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import re 4 | import numpy as np 5 | import pandas as pd 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | # 35 attributes which contain enough non-missing values 9 | attributes = ['DiasABP', 'HR', 'Na', 'Lactate', 'NIDiasABP', 'PaO2', 'WBC', 'pH', 'Albumin', 'ALT', 'Glucose', 'SaO2', 10 | 'Temp', 'AST', 'Bilirubin', 'HCO3', 'BUN', 'RespRate', 'Mg', 'HCT', 'SysABP', 'FiO2', 'K', 'GCS', 11 | 'Cholesterol', 'NISysABP', 'TroponinT', 'MAP', 'TroponinI', 'PaCO2', 'Platelets', 'Urine', 'NIMAP', 12 | 'Creatinine', 'ALP'] 13 | 14 | 15 | def extract_hour(x): 16 | """ 17 | Extracts and returns the hour from a time string. PhysioNet data is processed at an hourly granularity. 18 | 19 | Parameters: 20 | - x (str): Time string in the format 'HH:MM'. 21 | 22 | Returns: 23 | - int: The hour part extracted from the time string. 24 | 25 | """ 26 | h, _ = map(int, x.split(":")) 27 | return h 28 | 29 | 30 | def parse_data(x): 31 | """ 32 | Processes a pandas DataFrame to extract the recorded value for each attribute in 'attributes'. 33 | Returns a list of values with NaN for any missing attribute. 34 | 35 | Parameters: 36 | - x (DataFrame): DataFrame containing time series data for various attributes. 37 | 38 | Returns: 39 | - list: A list of observed values for the specified attributes, with NaN for missing ones.
40 | """ 41 | 42 | # extract the last value for each attribute 43 | x = x.set_index("Parameter").to_dict()["Value"] 44 | 45 | values = [] 46 | 47 | for attr in attributes: 48 | if x.__contains__(attr): 49 | values.append(x[attr]) 50 | else: 51 | values.append(np.nan) 52 | return values 53 | 54 | def extract_record_id_and_death_status_as_dict(file_path): 55 | """ 56 | Creates a dictionary mapping patient RecordID to their in-hospital death status from a CSV file. 57 | 58 | Parameters: 59 | - file_path (str): Path to the CSV file containing RecordID and In-hospital_death columns. 60 | 61 | Returns: 62 | - dict: A dictionary with RecordID as keys and in-hospital death status as values. 63 | """ 64 | 65 | # Read the data from the file 66 | df = pd.read_csv(file_path) 67 | 68 | # Extract RecordID and In-hospital_death columns and convert to a dictionary 69 | result_dict = df.set_index('RecordID')['In-hospital_death'].to_dict() 70 | 71 | # Print or return the result as needed 72 | return result_dict 73 | 74 | 75 | # File path (change this to your file path) 76 | file_path = './data/physio/Outcomes-a.txt' 77 | 78 | # Run the function with the file path and print the results 79 | id_label_mapping = extract_record_id_and_death_status_as_dict(file_path) 80 | 81 | 82 | def parse_id(id_, missing_ratio=0.1, mode='imputation'): 83 | """ 84 | Reads and preprocesses patient data, applies missing data handling, and prepares it for model input. 85 | 86 | Parameters: 87 | - id_ (str): The patient ID to process. 88 | - missing_ratio (float): The ratio of data to be considered as missing, masked and used for evaluation. 89 | - mode (str): The mode of handling missing data, either 'imputation' or 'interpolation'. 90 | 91 | Returns: 92 | - Tuple containing processed data arrays and labels: (observed_values, observed_masks, gt_masks, label) 93 | """ 94 | 95 | data = pd.read_csv("./data/physio/set-a/{}.txt".format(id_)) 96 | # set hour 97 | data["Time"] = data["Time"].apply(lambda x: extract_hour(x)) 98 | 99 | # create data for 48 hours x 35 attributes 100 | observed_values = [] 101 | for h in range(48): 102 | observed_values.append(parse_data(data[data["Time"] == h])) 103 | observed_values = np.array(observed_values) 104 | observed_masks = ~np.isnan(observed_values) 105 | if mode == 'imputation': 106 | # randomly set some percentage as ground-truth 107 | masks = observed_masks.reshape(-1).copy() 108 | obs_indices = np.where(masks)[0].tolist() 109 | miss_indices = np.random.choice( 110 | obs_indices, (int)(len(obs_indices) * missing_ratio), replace=False 111 | ) 112 | masks[miss_indices] = False 113 | gt_masks = masks.reshape(observed_masks.shape) 114 | elif mode == 'interpolation': 115 | # randomly set some percentage of timestamps values as ground-truth 116 | timestamps = np.arange(48) 117 | miss_timestamps = np.random.choice( 118 | timestamps, (int)(len(timestamps) * missing_ratio), replace=False 119 | ) 120 | masks = observed_masks.copy() 121 | masks[miss_timestamps,:] = False 122 | gt_masks = masks 123 | else: 124 | print('Please set mode to interpolation or imputation!') 125 | 126 | observed_values = np.nan_to_num(observed_values) 127 | observed_masks = observed_masks.astype("float32") 128 | gt_masks = gt_masks.astype("float32") 129 | label = id_label_mapping[int(id_)] 130 | return observed_values, observed_masks, gt_masks, label 131 | 132 | 133 | def get_idlist(): 134 | """ 135 | Scans a directory for files matching the patient ID pattern and returns a sorted list of IDs. 
136 | 137 | Returns: 138 | - list: A sorted list of patient IDs extracted from filenames. 139 | """ 140 | patient_id = [] 141 | for filename in os.listdir("./data/physio/set-a"): 142 | match = re.search(r"\d{6}", filename) 143 | if match: 144 | patient_id.append(match.group()) 145 | patient_id = np.sort(patient_id) 146 | return patient_id 147 | 148 | 149 | class Physio_Dataset(Dataset): 150 | """ 151 | A Dataset class for loading and processing PhysioNet data for use in PyTorch. 152 | 153 | Parameters: 154 | - eval_length (int): The length of the time series evaluation window (Total number of timestamps in each MTS(L)). 155 | - use_index_list (list, optional): List of indices to use from the dataset (to differentiate between training, validation and test sets). 156 | - missing_ratio (float): Ratio of data to mask as missing and use for evaluation (set to 0.1, 0.5 or 0.9). 157 | - seed (int): Random seed for reproducibility (used to randomly select the same subset of values and mask them). 158 | - mode (str): Mode for handling missing data ('imputation' or 'interpolation'). 159 | """ 160 | def __init__(self, eval_length=48, use_index_list=None, missing_ratio=0.0, seed=0, mode='imputation'): 161 | self.eval_length = eval_length 162 | np.random.seed(seed) # seed for ground truth choice 163 | 164 | self.observed_values = [] 165 | self.observed_masks = [] 166 | self.gt_masks = [] 167 | self.labels = [] 168 | path = ( 169 | "./data/physio_missing" + str(missing_ratio) + "_" + mode + "_seed" + str(seed) + ".pk" 170 | ) 171 | 172 | if os.path.isfile(path) == False: # if datasetfile is none, create 173 | idlist = get_idlist() 174 | for id_ in idlist: 175 | try: 176 | observed_values, observed_masks, gt_masks, label = parse_id( 177 | id_, missing_ratio, mode=mode 178 | ) 179 | self.observed_values.append(observed_values) 180 | self.observed_masks.append(observed_masks) 181 | self.gt_masks.append(gt_masks) 182 | self.labels.append(label) 183 | except Exception as e: 184 | print(id_, e) 185 | continue 186 | self.observed_values = np.array(self.observed_values) 187 | self.observed_masks = np.array(self.observed_masks) 188 | self.gt_masks = np.array(self.gt_masks) 189 | 190 | # calc mean and std and normalize values 191 | # (it is the same normalization as Cao et al. (2018) (https://github.com/caow13/BRITS)) 192 | tmp_values = self.observed_values.reshape(-1, 35) 193 | tmp_masks = self.observed_masks.reshape(-1, 35) 194 | mean = np.zeros(35) 195 | std = np.zeros(35) 196 | for k in range(35): 197 | c_data = tmp_values[:, k][tmp_masks[:, k] == 1] 198 | mean[k] = c_data.mean() 199 | std[k] = c_data.std() 200 | self.observed_values = ( 201 | (self.observed_values - mean) / std * self.observed_masks 202 | ) 203 | 204 | with open(path, "wb") as f: 205 | pickle.dump( 206 | [self.observed_values, self.observed_masks, self.gt_masks, self.labels], f 207 | ) 208 | else: # load datasetfile 209 | with open(path, "rb") as f: 210 | self.observed_values, self.observed_masks, self.gt_masks, self.labels = pickle.load( 211 | f 212 | ) 213 | if use_index_list is None: 214 | self.use_index_list = np.arange(len(self.observed_values)) 215 | else: 216 | self.use_index_list = use_index_list 217 | 218 | def __getitem__(self, org_index): 219 | """ 220 | Returns a sample from the dataset at the specified index. 221 | 222 | Parameters: 223 | - org_index (int): The index of the sample to retrieve. 224 | 225 | Returns: 226 | - dict: A dictionary containing the data sample.
227 | """ 228 | index = self.use_index_list[org_index] 229 | s = { 230 | "observed_data": self.observed_values[index], 231 | "observed_mask": self.observed_masks[index], 232 | "gt_mask": self.gt_masks[index], 233 | "timepoints": np.arange(self.eval_length), 234 | "labels": self.labels[index], 235 | } 236 | return s 237 | 238 | def __len__(self): 239 | """ 240 | Returns the total number of samples in the dataset. 241 | 242 | Returns: 243 | - int: Total number of samples. 244 | """ 245 | return len(self.use_index_list) 246 | 247 | 248 | def get_dataloader_physio(seed=1, nfold=None, batch_size=16, missing_ratio=0.1, mode='imputation'): 249 | """ 250 | Prepares DataLoader objects for the PhysioNet dataset for training, validation, and testing. 251 | 252 | Parameters: 253 | - seed (int): Random seed for reproducibility. 254 | - nfold (int, optional): Current fold number for cross-validation. 255 | - batch_size (int): Batch size for the DataLoader. 256 | - missing_ratio (float): Ratio of data to mask as missing. 257 | - mode (str): Mode for handling missing data ('imputation' or 'interpolation'). 258 | 259 | Returns: 260 | - Tuple: Contains DataLoader objects for training, validation, and testing. 261 | """ 262 | # only to obtain total length of dataset 263 | dataset = Physio_Dataset(missing_ratio=missing_ratio, seed=seed, mode=mode) 264 | indlist = np.arange(len(dataset)) 265 | 266 | np.random.seed(seed) 267 | np.random.shuffle(indlist) 268 | 269 | # 5-fold test 270 | start = (int)(nfold * 0.2 * len(dataset)) 271 | end = (int)((nfold + 1) * 0.2 * len(dataset)) 272 | test_index = indlist[start:end] 273 | remain_index = np.delete(indlist, np.arange(start, end)) 274 | 275 | np.random.seed(seed) 276 | np.random.shuffle(remain_index) 277 | num_train = (int)(len(dataset) * 0.7) 278 | train_index = remain_index[:num_train] 279 | valid_index = remain_index[num_train:] 280 | 281 | dataset = Physio_Dataset( 282 | use_index_list=train_index, missing_ratio=missing_ratio, seed=seed, mode=mode 283 | ) 284 | train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=1) 285 | valid_dataset = Physio_Dataset( 286 | use_index_list=valid_index, missing_ratio=missing_ratio, seed=seed, mode=mode 287 | ) 288 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=0) 289 | test_dataset = Physio_Dataset( 290 | use_index_list=test_index, missing_ratio=missing_ratio, seed=seed, mode=mode 291 | ) 292 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=0) 293 | return train_loader, valid_loader, test_loader 294 | 295 | 296 | 297 | 298 | class Imputed_Physio_Dataset(Dataset): 299 | """ 300 | A Dataset class for loading imputed PhysioNet data in PyTorch (used for classification and clustering experiments). 301 | 302 | Parameters: 303 | - filename (str): Path to the file containing the imputed time series. 304 | - flag (str): Indicates the subset of the data ('Train', 'Val', or 'Test'). 
305 | """ 306 | def __init__(self, filename, flag='Train'): 307 | 308 | self.file = filename 309 | self.flag = flag 310 | data_path = self.file + f'generated_outputs_{flag}_nsample100.pk' 311 | 312 | with open(data_path, 'rb') as f: 313 | self.samples, self.all_target, self.all_evalpoint, self.all_observed, self.all_observed_time, self.labels, self.scaler, self.mean_scaler = pickle.load(f) 314 | 315 | self.imputed_mts = self.samples.median(dim=1)[0]*(1-self.all_observed)+self.all_target*self.all_observed 316 | 317 | 318 | def __len__(self): 319 | """ 320 | Returns the total number of samples in the dataset. 321 | 322 | Returns: 323 | - int: Total number of samples. 324 | """ 325 | return len(self.imputed_mts) 326 | 327 | def __getitem__(self, index): 328 | """ 329 | Returns a sample from the dataset at the specified index. 330 | 331 | Parameters: 332 | - index (int): The index of the sample to retrieve. 333 | 334 | Returns: 335 | - dict: A dictionary containing the data sample. 336 | """ 337 | s = { 338 | "observed_data": self.imputed_mts[index].cpu(), 339 | "observed_mask": np.ones_like(self.imputed_mts[index].cpu()), 340 | "gt_mask": np.ones_like(self.imputed_mts[index].cpu()), 341 | "timepoints": np.arange(self.imputed_mts[index].shape[0]), 342 | "y": self.labels[index], 343 | "labels": self.labels[index], 344 | } 345 | return s 346 | 347 | 348 | def get_physio_dataloader_for_classification(filename, batch_size): 349 | """ 350 | Prepares DataLoader objects for classification and clustering experiments using imputed PhysioNet data. 351 | 352 | Parameters: 353 | - filename (str): Path to the file containing the imputed dataset. 354 | - batch_size (int): Batch size for the DataLoader. 355 | 356 | Returns: 357 | - Tuple: Contains DataLoader objects for training, validation, and testing. 358 | """ 359 | train_dataset = Imputed_Physio_Dataset( 360 | filename, flag = 'Train' 361 | ) 362 | valid_dataset = Imputed_Physio_Dataset( 363 | filename, flag = 'Val' 364 | ) 365 | test_dataset = Imputed_Physio_Dataset( 366 | filename, flag = 'Test' 367 | ) 368 | 369 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=1) 370 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=0) 371 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=0) 372 | 373 | return train_loader, valid_loader, test_loader 374 | -------------------------------------------------------------------------------- /src/data_loader/forecasting_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import datetime 4 | import json 5 | import pickle 6 | import torch 7 | from torch.utils.data import DataLoader, Dataset 8 | 9 | 10 | def preprocess_dataset(dataset_name, train_length, test_length, skip_length, history_length): 11 | """ 12 | Preprocesses and saves a specified dataset for forecasting task, including training, testing, and validation sets. 13 | This function saves the corresponding processed splits (train and test), and the mean and standard deviation for the training set used for normalization. 14 | 15 | 16 | Parameters: 17 | - dataset_name (str): The name of the dataset. It should be one of the following: electricity, solar, taxi, traffic or wiki. 18 | - train_length (int): The length of the training sequences. It refers to the total number of timestamps in training and validation sets. 19 | - test_length (int): The length of the testing sequences. 
It refers to the total number of timestamps in the test set. 20 | - skip_length (int): The number of sequences to skip between training and testing. The number of timestamps to skip in the test set (we are evaluating only on a subset). 21 | - history_length (int): The length of the historical data to consider. The total number of timestamps to use as history window in every MTS. 22 | 23 | 24 | """ 25 | 26 | path_train = f'./data/{dataset_name}/{dataset_name}_nips/train/data.json' #train 27 | path_test = f'./data/{dataset_name}/{dataset_name}_nips/test/data.json' #test 28 | 29 | 30 | main_data=[] 31 | mask_data=[] 32 | 33 | hour_data=None 34 | 35 | 36 | with open(path_train, 'r') as file: 37 | data_train = [json.loads(line) for line in file] 38 | 39 | with open(path_test, 'r') as file: 40 | data_test = [json.loads(line) for line in file] 41 | 42 | ## Prepare Train Sequences 43 | for obj in data_train: 44 | tmp_data = np.array(obj['target']) 45 | tmp_mask = np.ones_like(tmp_data) 46 | 47 | if len(tmp_data) == train_length and hour_data is None: 48 | c_time = datetime.datetime.strptime(obj['start'],'%Y-%m-%d %H:%M:%S') 49 | hour_data = [] 50 | day_data = [] 51 | time_data = [] 52 | for k in range(train_length): 53 | time_data.append(c_time) 54 | hour_data.append(c_time.hour) 55 | day_data.append(c_time.weekday()) 56 | c_time = c_time + datetime.timedelta(hours=1) 57 | else: 58 | if len(tmp_data) != train_length: #fill NA by 0 59 | tmp_padding = np.zeros(train_length-len(tmp_data)) 60 | tmp_data = np.concatenate([tmp_padding,tmp_data]) 61 | tmp_mask = np.concatenate([tmp_padding,tmp_mask]) 62 | 63 | 64 | main_data.append(tmp_data) 65 | mask_data.append(tmp_mask) 66 | 67 | ## Prepare Test Sequences 68 | if dataset_name != "solar": 69 | cnt = 0 70 | ind = 0 71 | 72 | for line in data_test: 73 | cnt+=1 74 | if cnt <=skip_length: 75 | continue 76 | tmp_data = np.array(line['target']) 77 | tmp_data = tmp_data[-test_length-history_length:] 78 | 79 | tmp_mask = np.ones_like(tmp_data) 80 | 81 | main_data[ind] = np.concatenate([main_data[ind],tmp_data]) 82 | mask_data[ind] = np.concatenate([mask_data[ind],tmp_mask]) 83 | c_time = datetime.datetime.strptime(obj['start'],'%Y-%m-%d %H:%M:%S') 84 | for i in range(test_length+history_length): 85 | time_data.append(c_time) 86 | hour_data.append(c_time.hour) 87 | day_data.append(c_time.weekday()) 88 | c_time = c_time + datetime.timedelta(hours=1) 89 | ind += 1 90 | main_data = np.stack(main_data,-1) 91 | mask_data = np.stack(mask_data,-1) 92 | print('Main data shape', main_data.shape) 93 | ## Save means 94 | mean_data = main_data[:-test_length-history_length].mean(0) 95 | std_data = main_data[:-test_length-history_length].std(0) 96 | 97 | 98 | ## Save means 99 | paths=f'./data/{dataset_name}/{dataset_name}_nips/meanstd.pkl' 100 | if os.path.isfile(paths) == False: 101 | with open(paths, 'wb') as f: 102 | pickle.dump([mean_data,std_data],f) 103 | 104 | ## Save sequences 105 | paths=f'./data/{dataset_name}/{dataset_name}_nips/data.pkl' 106 | if os.path.isfile(paths) == False: 107 | with open(paths, 'wb') as f: 108 | pickle.dump([main_data,mask_data],f) 109 | 110 | 111 | def preprocess_taxi(train_length, test_length, skip_length, history_length): 112 | """ 113 | Specialized preprocessing function for the taxi dataset, similar to preprocess_dataset but tailored to its structure. 114 | 115 | Parameters: 116 | - train_length (int): The length of the training sequences. It refers to the total number of timestamps in training and validation sets. 
117 | - test_length (int): The length of the testing sequences. It refers to the total number of timestamps in the test set. 118 | - skip_length (int): The number of sequences to skip between training and testing. The number of timestamps to skip in the test set (we are evaluating only on a subset). 119 | - history_length (int): The length of the historical data to consider. The total number of timestamps to use as history window in every MTS. 120 | """ 121 | 122 | path_train = f'./data/taxi/taxi_nips/train/data.json' #train 123 | path_test = f'./data/taxi/taxi_nips/test/data.json' #test 124 | 125 | 126 | main_data=[] 127 | mask_data=[] 128 | 129 | hour_data=None 130 | 131 | 132 | with open(path_train, 'r') as file: 133 | data_train = [json.loads(line) for line in file] 134 | 135 | with open(path_test, 'r') as file: 136 | data_test = [json.loads(line) for line in file] 137 | 138 | ## Prepare Train Sequences 139 | for obj in data_train: 140 | tmp_data = np.array(obj['target']) 141 | tmp_mask = np.ones_like(tmp_data) 142 | 143 | if len(tmp_data) == train_length and hour_data is None: 144 | c_time = datetime.datetime.strptime(obj['start'],'%Y-%m-%d %H:%M:%S') 145 | hour_data = [] 146 | day_data = [] 147 | time_data = [] 148 | for k in range(train_length): 149 | time_data.append(c_time) 150 | hour_data.append(int(c_time.hour+c_time.minute/30)) 151 | day_data.append(c_time.weekday()) 152 | c_time = c_time + datetime.timedelta(minutes=30) 153 | else: 154 | if len(tmp_data) != train_length: #fill NA by 0 155 | tmp_padding = np.zeros(train_length-len(tmp_data)) 156 | tmp_data = np.concatenate([tmp_padding,tmp_data]) 157 | tmp_mask = np.concatenate([tmp_padding,tmp_mask]) 158 | 159 | 160 | main_data.append(tmp_data) 161 | mask_data.append(tmp_mask) 162 | 163 | ## Prepare Test Sequences 164 | cnt = 0 165 | ind = 0 166 | 167 | for line in data_test: 168 | cnt+=1 169 | if cnt <=skip_length: 170 | continue 171 | tmp_data = np.array(line['target']) 172 | tmp_data = tmp_data[-test_length-history_length:] 173 | 174 | tmp_mask = np.ones_like(tmp_data) 175 | 176 | main_data[ind] = np.concatenate([main_data[ind],tmp_data]) 177 | mask_data[ind] = np.concatenate([mask_data[ind],tmp_mask]) 178 | c_time = datetime.datetime.strptime(obj['start'],'%Y-%m-%d %H:%M:%S') 179 | for i in range(test_length+history_length): 180 | time_data.append(c_time) 181 | hour_data.append(c_time.hour) 182 | day_data.append(c_time.weekday()) 183 | c_time = c_time + datetime.timedelta(minutes=30) 184 | ind += 1 185 | 186 | main_data = np.stack(main_data,-1) 187 | mask_data = np.stack(mask_data,-1) 188 | 189 | 190 | ## Save mean 191 | mean_data = main_data[:-test_length-history_length].mean(0) 192 | std_data = main_data[:-test_length-history_length].std(0) 193 | 194 | ## Save means 195 | paths=f'./data/taxi/taxi_nips/meanstd.pkl' 196 | if os.path.isfile(paths) == False: 197 | with open(paths, 'wb') as f: 198 | pickle.dump([mean_data,std_data],f) 199 | 200 | ## Save sequences 201 | paths=f'./data/taxi/taxi_nips/data.pkl' 202 | if os.path.isfile(paths) == False: 203 | with open(paths, 'wb') as f: 204 | pickle.dump([main_data,mask_data],f) 205 | 206 | def preprocess_wiki(train_length, test_length, skip_length, history_length): 207 | """ 208 | Specialized preprocessing function for the wiki dataset, similar to preprocess_dataset but tailored to its structure. 209 | 210 | Parameters: 211 | - train_length (int): The length of the training sequences. It refers to the total number of timestamps in training and validation sets. 
212 | - test_length (int): The length of the testing sequences. It refers to the total number of timestamps in the test set. 213 | - skip_length (int): The number of sequences to skip between training and testing. The number of timestamps to skip in the test set (we are evaluating only on a subset). 214 | - history_length (int): The length of the historical data to consider. The total number of timestamps to use as history window in every MTS. 215 | """ 216 | path_train = f'./data/wiki/wiki_nips/train/data.json' #train 217 | path_test = f'./data/wiki/wiki_nips/test/data.json' #test 218 | 219 | 220 | main_data=[] 221 | mask_data=[] 222 | 223 | hour_data=None 224 | 225 | 226 | with open(path_train, 'r') as file: 227 | data_train = [json.loads(line) for line in file] 228 | 229 | with open(path_test, 'r') as file: 230 | data_test = [json.loads(line) for line in file] 231 | 232 | ## Prepare Train Sequences 233 | for obj in data_train: 234 | tmp_data = np.array(obj['target']) 235 | tmp_mask = np.ones_like(tmp_data) 236 | 237 | if len(tmp_data) == train_length and hour_data is None: 238 | c_time = datetime.datetime.strptime(obj['start'],'%Y-%m-%d %H:%M:%S') 239 | hour_data = [] 240 | day_data = [] 241 | time_data = [] 242 | for k in range(train_length): 243 | time_data.append(c_time) 244 | hour_data.append(c_time.hour) 245 | day_data.append(c_time.weekday()) 246 | c_time = c_time + datetime.timedelta(days=1) 247 | else: 248 | if len(tmp_data) != train_length: #fill NA by 0 249 | tmp_padding = np.zeros(train_length-len(tmp_data)) 250 | tmp_data = np.concatenate([tmp_padding,tmp_data]) 251 | tmp_mask = np.concatenate([tmp_padding,tmp_mask]) 252 | 253 | 254 | main_data.append(tmp_data) 255 | mask_data.append(tmp_mask) 256 | 257 | ## Prepare Test Sequences 258 | cnt = 0 259 | ind = 0 260 | 261 | for line in data_test: 262 | cnt+=1 263 | if cnt <=skip_length: 264 | continue 265 | tmp_data = np.array(line['target']) 266 | tmp_data = tmp_data[-test_length-history_length:] 267 | 268 | tmp_mask = np.ones_like(tmp_data) 269 | 270 | main_data[ind] = np.concatenate([main_data[ind],tmp_data]) 271 | mask_data[ind] = np.concatenate([mask_data[ind],tmp_mask]) 272 | c_time = datetime.datetime.strptime(obj['start'],'%Y-%m-%d %H:%M:%S') 273 | 274 | ind += 1 275 | main_data = np.stack(main_data[-2000:],-1) 276 | mask_data = np.stack(mask_data[-2000:],-1) 277 | 278 | mean_data = main_data[:-test_length-history_length].mean(0) 279 | std_data = main_data[:-test_length-history_length].std(0) 280 | 281 | ## Save means 282 | paths=f'./data/wiki/wiki_nips/meanstd.pkl' 283 | if os.path.isfile(paths) == False: 284 | with open(paths, 'wb') as f: 285 | pickle.dump([mean_data,std_data],f) 286 | 287 | ## Save sequences 288 | paths=f'./data/wiki/wiki_nips/data.pkl' 289 | if os.path.isfile(paths) == False: 290 | with open(paths, 'wb') as f: 291 | pickle.dump([main_data,mask_data],f) 292 | 293 | class Forecasting_Dataset(Dataset): 294 | """ 295 | A PyTorch Dataset class for loading and preparing forecasting data. 296 | 297 | Parameters: 298 | - dataset_name (str): The name of the dataset. One of the following: electricity, solar, traffic, taxi or wiki. 299 | - train_length, skip_length, valid_length, test_length, pred_length, history_length (int): Parameters defining the dataset structure and lengths of different segments as described in the processing functions. 300 | - is_train (int): Indicator of the dataset split (0 for test, 1 for train, 2 for valid). 
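Example (illustrative sketch; typical use is through get_dataloader_forecasting, defined later in this file; train_length and skip_length are dataset specific, so the length values below are placeholders rather than the ones used in the experiments):

            train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader_forecasting(
                dataset_name='electricity',
                train_length=10000, skip_length=0,   # placeholder values, set per dataset by the experiment script
                valid_length=24*5, test_length=24*7, pred_length=24,
                history_length=168, batch_size=8, device='cuda:0')
            # Every sample spans history_length + pred_length timestamps; 'gt_mask' zeroes
            # the final pred_length steps, which are the values the model must forecast.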
301 | """ 302 | def __init__(self, dataset_name, train_length, skip_length, valid_length, test_length, pred_length, history_length, is_train): 303 | self.history_length = history_length 304 | self.pred_length = pred_length 305 | self.test_length = test_length 306 | self.valid_length = valid_length 307 | self.data_type = dataset_name 308 | self.seq_length = self.pred_length+self.history_length 309 | 310 | if dataset_name == 'taxi': 311 | preprocess_taxi(train_length, test_length, skip_length, history_length) 312 | elif dataset_name == 'wiki': 313 | preprocess_wiki(train_length, test_length, skip_length, history_length) 314 | else: 315 | preprocess_dataset(dataset_name, train_length, test_length, skip_length, history_length) 316 | 317 | paths = f'./data/{self.data_type}/{self.data_type}_nips/data.pkl' 318 | mean_path = f'./data/{self.data_type}/{self.data_type}_nips/meanstd.pkl' 319 | with open(paths, 'rb') as f: 320 | self.main_data,self.mask_data=pickle.load(f) 321 | with open(mean_path, 'rb') as f: 322 | self.mean_data,self.std_data=pickle.load(f) 323 | 324 | self.main_data = (self.main_data - self.mean_data) / np.maximum(1e-5,self.std_data) 325 | 326 | data_length = len(self.main_data) 327 | if is_train == 0: #test 328 | start = data_length - self.seq_length - self.test_length + self.pred_length 329 | end = data_length - self.seq_length + self.pred_length 330 | self.use_index = np.arange(start,end,self.pred_length) 331 | print('Test', start, end) 332 | 333 | if is_train == 2: #valid 334 | start = data_length - self.seq_length - self.valid_length - self.test_length + self.pred_length 335 | end = data_length - self.seq_length - self.test_length + self.pred_length 336 | self.use_index = np.arange(start,end,self.pred_length) 337 | print('Val', start, end) 338 | if is_train == 1: 339 | start = 0 340 | end = data_length - self.seq_length - self.valid_length - self.test_length + 1 341 | self.use_index = np.arange(start,end,1) 342 | print('Train', start, end) 343 | 344 | 345 | def __getitem__(self, orgindex): 346 | """ 347 | Gets the MTS at the specified index. 348 | 349 | Parameters: 350 | - orgindex (int): The index of the MTS (index of the start timestamp of the sequence). 351 | 352 | Returns: 353 | - dict: A dictionary containing 'observed_data', 'observed_mask', 'gt_mask', 'timepoints', and 'feature_id'. 354 | """ 355 | index = self.use_index[orgindex] 356 | target_mask = self.mask_data[index:index+self.seq_length].copy() 357 | target_mask[-self.pred_length:] = 0. 358 | s = { 359 | 'observed_data': self.main_data[index:index+self.seq_length], 360 | 'observed_mask': self.mask_data[index:index+self.seq_length], 361 | 'gt_mask': target_mask, 362 | 'timepoints': np.arange(self.seq_length) * 1.0, 363 | 'feature_id': np.arange(self.main_data.shape[1]) * 1.0, 364 | } 365 | 366 | return s 367 | 368 | def __len__(self): 369 | """ 370 | Returns the total number of samples in the dataset. 371 | 372 | Returns: 373 | - int: The total number of samples. 374 | """ 375 | return len(self.use_index) 376 | 377 | 378 | def get_dataloader_forecasting(dataset_name, train_length, skip_length, valid_length=24*5, test_length =24*7, pred_length=24, history_length=168, batch_size=8, device='cuda:0'): 379 | """ 380 | Prepares DataLoader objects for the forecasting datasets. 381 | 382 | Parameters: 383 | - dataset_name (str): The name of the dataset. 
384 | - train_length, skip_length, valid_length, test_length, pred_length, history_length, batch_size (int): Various parameters defining dataset and DataLoader configurations. 385 | - device (str): The device to use for loading tensors. 386 | 387 | Returns: 388 | - Tuple[DataLoader, DataLoader, DataLoader, Tensor, Tensor]: Training, validation, and testing DataLoaders, along with scale and mean scale tensors used for normalization. 389 | """ 390 | 391 | train_dataset = Forecasting_Dataset(dataset_name, train_length, skip_length, valid_length, test_length, pred_length, history_length, is_train=1) 392 | train_loader = DataLoader( 393 | train_dataset, batch_size=batch_size, shuffle=True) 394 | 395 | valid_dataset = Forecasting_Dataset(dataset_name, train_length, skip_length, valid_length, test_length, pred_length, history_length, is_train=2) 396 | valid_loader = DataLoader( 397 | valid_dataset, batch_size=batch_size, shuffle=False) 398 | 399 | test_dataset = Forecasting_Dataset(dataset_name, train_length, skip_length, valid_length, test_length, pred_length, history_length, is_train=0) 400 | test_loader = DataLoader( 401 | test_dataset, batch_size=batch_size, shuffle=False) 402 | scaler = torch.from_numpy(train_dataset.std_data).to(device).float() 403 | mean_scaler = torch.from_numpy(train_dataset.mean_data).to(device).float() 404 | return train_loader, valid_loader, test_loader, scaler, mean_scaler -------------------------------------------------------------------------------- /src/data_loader/anomaly_detection_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from torch.utils.data import Dataset, DataLoader 5 | from sklearn.preprocessing import StandardScaler 6 | 7 | 8 | root_path = './data/anomaly_detection/' 9 | class PSMSegLoader(Dataset): 10 | """ 11 | A PyTorch Dataset class for loading and preprocessing the PSM dataset for anomaly detection. 12 | The dataset is segmented based on window size and step (for sliding). 13 | 14 | Parameters: 15 | - win_size: The size of the window to segment the dataset. 16 | - step: The step size for sliding window. 17 | - flag: Indicates the dataset split to use ('train', 'val', 'test'). 18 | 19 | Attributes: 20 | - train: Training data after scaling. 21 | - val: Validation data after scaling. 22 | - test: Test data after scaling. 23 | - test_labels: Labels for the test data. 
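All of the *SegLoader classes in this file share the same sliding-window arithmetic, sketched here once for reference:

            # number_of_windows = (num_rows - win_size) // step + 1
            # e.g. win_size=100 and step=100 over 1,000 rows gives (1000 - 100)//100 + 1 = 10
            # non-overlapping windows, and __getitem__(i) returns the window starting at row i * step.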
24 | """ 25 | def __init__(self, win_size, step=100, flag="train"): 26 | self.flag = flag 27 | self.step = step 28 | self.win_size = win_size 29 | self.scaler = StandardScaler() 30 | self.root_path = root_path+'PSM/' 31 | data = pd.read_csv(os.path.join(self.root_path, 'train.csv')) 32 | data = data.values[:, 1:] 33 | data = np.nan_to_num(data) 34 | self.scaler.fit(data) 35 | data = self.scaler.transform(data) 36 | test_data = pd.read_csv(os.path.join(self.root_path, 'test.csv')) 37 | test_data = test_data.values[:, 1:] 38 | test_data = np.nan_to_num(test_data) 39 | self.test = self.scaler.transform(test_data) 40 | self.train = data 41 | data_len = len(self.train) 42 | self.val = self.train[(int)(data_len * 0.8):] 43 | self.test_labels = pd.read_csv(os.path.join(self.root_path, 'test_label.csv')).values[:, 1:] 44 | print("test:", self.test.shape) 45 | print("train:", self.train.shape) 46 | 47 | def __len__(self): 48 | if self.flag == "train": 49 | return (self.train.shape[0] - self.win_size) // self.step + 1 50 | elif (self.flag == 'val'): 51 | return (self.val.shape[0] - self.win_size) // self.step + 1 52 | elif (self.flag == 'test'): 53 | return (self.test.shape[0] - self.win_size) // self.step + 1 54 | else: 55 | return (self.test.shape[0] - self.win_size) // self.win_size + 1 56 | 57 | def __getitem__(self, index): 58 | index = index * self.step 59 | if self.flag == "train": 60 | data_point, label = np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 61 | elif (self.flag == 'val'): 62 | data_point, label = np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 63 | elif (self.flag == 'test'): 64 | data_point, label = np.float32(self.test[index:index + self.win_size]), np.float32( 65 | self.test_labels[index:index + self.win_size]) 66 | else: 67 | data_point, label = np.float32(self.test[ 68 | index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32( 69 | self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]) 70 | s = { 71 | 'observed_data': data_point, 72 | 'observed_mask': np.ones_like(data_point), 73 | 'gt_mask': np.ones_like(data_point), 74 | 'timepoints': np.arange(self.win_size) * 1.0, 75 | 'feature_id': np.arange(data_point.shape[1]) * 1.0, 76 | 'label': label, 77 | } 78 | 79 | return s 80 | 81 | class MSLSegLoader(Dataset): 82 | """ 83 | A PyTorch Dataset class for loading and preprocessing the MSL dataset for anomaly detection. 84 | The dataset is segmented based on window size and step (for sliding). 85 | 86 | Parameters: 87 | - win_size: The size of the window to segment the dataset. 88 | - step: The step size for sliding window. 89 | - flag: Indicates the dataset split to use ('train', 'val', 'test'). 90 | 91 | Attributes: 92 | - train: Training data after scaling. 93 | - val: Validation data after scaling. 94 | - test: Test data after scaling. 95 | - test_labels: Labels for the test data. 
96 | """ 97 | def __init__(self, win_size, step=100, flag="train"): 98 | self.flag = flag 99 | self.step = step 100 | self.win_size = win_size 101 | self.root_path = root_path+'MSL/' 102 | self.scaler = StandardScaler() 103 | data = np.load(os.path.join(self.root_path, "MSL_train.npy")) 104 | self.scaler.fit(data) 105 | data = self.scaler.transform(data) 106 | test_data = np.load(os.path.join(self.root_path, "MSL_test.npy")) 107 | self.test = self.scaler.transform(test_data) 108 | self.train = data 109 | data_len = len(self.train) 110 | self.val = self.train[(int)(data_len * 0.8):] 111 | self.test_labels = np.load(os.path.join(self.root_path, "MSL_test_label.npy")) 112 | print("test:", self.test.shape) 113 | print("train:", self.train.shape) 114 | 115 | def __len__(self): 116 | if self.flag == "train": 117 | return (self.train.shape[0] - self.win_size) // self.step + 1 118 | elif (self.flag == 'val'): 119 | return (self.val.shape[0] - self.win_size) // self.step + 1 120 | elif (self.flag == 'test'): 121 | return (self.test.shape[0] - self.win_size) // self.step + 1 122 | else: 123 | return (self.test.shape[0] - self.win_size) // self.win_size + 1 124 | 125 | def __getitem__(self, index): 126 | index = index * self.step 127 | if self.flag == "train": 128 | data_point, label = np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 129 | elif (self.flag == 'val'): 130 | data_point, label = np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 131 | elif (self.flag == 'test'): 132 | data_point, label = np.float32(self.test[index:index + self.win_size]), np.float32( 133 | self.test_labels[index:index + self.win_size]) 134 | else: 135 | data_point = np.float32(self.test[ 136 | index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32( 137 | self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]) 138 | 139 | s = { 140 | 'observed_data': data_point, 141 | 'observed_mask': np.ones_like(data_point), 142 | 'gt_mask': np.ones_like(data_point), 143 | 'timepoints': np.arange(self.win_size) * 1.0, 144 | 'feature_id': np.arange(data_point.shape[1]) * 1.0, 145 | 'label': label, 146 | } 147 | 148 | return s 149 | 150 | 151 | class SMAPSegLoader(Dataset): 152 | """ 153 | A PyTorch Dataset class for loading and preprocessing the SMAP dataset for anomaly detection. 154 | The dataset is segmented based on window size and step (for sliding). 155 | 156 | Parameters: 157 | - win_size: The size of the window to segment the dataset. 158 | - step: The step size for sliding window. 159 | - flag: Indicates the dataset split to use ('train', 'val', 'test'). 160 | 161 | Attributes: 162 | - train: Training data after scaling. 163 | - val: Validation data after scaling. 164 | - test: Test data after scaling. 165 | - test_labels: Labels for the test data. 
166 | """ 167 | def __init__(self, win_size, step=100, flag="train"): 168 | self.flag = flag 169 | self.step = step 170 | self.win_size = win_size 171 | self.root_path = root_path+'SMAP/' 172 | self.scaler = StandardScaler() 173 | data = np.load(os.path.join(self.root_path, "SMAP_train.npy")) 174 | self.scaler.fit(data) 175 | data = self.scaler.transform(data) 176 | test_data = np.load(os.path.join(self.root_path, "SMAP_test.npy")) 177 | self.test = self.scaler.transform(test_data) 178 | self.train = data 179 | data_len = len(self.train) 180 | self.val = self.train[(int)(data_len * 0.8):] 181 | self.test_labels = np.load(os.path.join(self.root_path, "SMAP_test_label.npy")) 182 | print("test:", self.test.shape) 183 | print("train:", self.train.shape) 184 | 185 | def __len__(self): 186 | 187 | if self.flag == "train": 188 | return (self.train.shape[0] - self.win_size) // self.step + 1 189 | elif (self.flag == 'val'): 190 | return (self.val.shape[0] - self.win_size) // self.step + 1 191 | elif (self.flag == 'test'): 192 | return (self.test.shape[0] - self.win_size) // self.step + 1 193 | else: 194 | return (self.test.shape[0] - self.win_size) // self.win_size + 1 195 | 196 | def __getitem__(self, index): 197 | index = index * self.step 198 | if self.flag == "train": 199 | data_point, label = np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 200 | elif (self.flag == 'val'): 201 | data_point, label = np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 202 | elif (self.flag == 'test'): 203 | data_point, label = np.float32(self.test[index:index + self.win_size]), np.float32( 204 | self.test_labels[index:index + self.win_size]) 205 | else: 206 | data_point, label = np.float32(self.test[ 207 | index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32( 208 | self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]) 209 | 210 | s = { 211 | 'observed_data': data_point, 212 | 'observed_mask': np.ones_like(data_point), 213 | 'gt_mask': np.ones_like(data_point), 214 | 'timepoints': np.arange(self.win_size) * 1.0, 215 | 'feature_id': np.arange(data_point.shape[1]) * 1.0, 216 | 'label': label, 217 | } 218 | 219 | return s 220 | 221 | 222 | class SMDSegLoader(Dataset): 223 | """ 224 | A PyTorch Dataset class for loading and preprocessing the SMD dataset for anomaly detection. 225 | The dataset is segmented based on window size and step (for sliding). 226 | 227 | Parameters: 228 | - win_size: The size of the window to segment the dataset. 229 | - step: The step size for sliding window. 230 | - flag: Indicates the dataset split to use ('train', 'val', 'test'). 231 | 232 | Attributes: 233 | - train: Training data after scaling. 234 | - val: Validation data after scaling. 235 | - test: Test data after scaling. 236 | - test_labels: Labels for the test data. 
237 | """ 238 | def __init__(self, win_size, step=100, flag="train"): 239 | self.flag = flag 240 | self.step = step 241 | self.win_size = win_size 242 | self.root_path = root_path+'SMD/' 243 | self.scaler = StandardScaler() 244 | data = np.load(os.path.join(self.root_path, "SMD_train.npy")) 245 | self.scaler.fit(data) 246 | data = self.scaler.transform(data) 247 | test_data = np.load(os.path.join(self.root_path, "SMD_test.npy")) 248 | self.test = self.scaler.transform(test_data) 249 | self.train = data 250 | data_len = len(self.train) 251 | self.val = self.train[(int)(data_len * 0.8):] 252 | self.test_labels = np.load(os.path.join(self.root_path, "SMD_test_label.npy")) 253 | 254 | def __len__(self): 255 | if self.flag == "train": 256 | return (self.train.shape[0] - self.win_size) // self.step + 1 257 | elif (self.flag == 'val'): 258 | return (self.val.shape[0] - self.win_size) // self.step + 1 259 | elif (self.flag == 'test'): 260 | return (self.test.shape[0] - self.win_size) // self.step + 1 261 | else: 262 | return (self.test.shape[0] - self.win_size) // self.win_size + 1 263 | 264 | def __getitem__(self, index): 265 | index = index * self.step 266 | if self.flag == "train": 267 | data_point, label = np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 268 | elif (self.flag == 'val'): 269 | data_point, label = np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 270 | elif (self.flag == 'test'): 271 | data_point, label = np.float32(self.test[index:index + self.win_size]), np.float32( 272 | self.test_labels[index:index + self.win_size]) 273 | else: 274 | data_point, label = np.float32(self.test[ 275 | index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32( 276 | self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]) 277 | 278 | s = { 279 | 'observed_data': data_point, 280 | 'observed_mask': np.ones_like(data_point), 281 | 'gt_mask': np.ones_like(data_point), 282 | 'timepoints': np.arange(self.win_size) * 1.0, 283 | 'feature_id': np.arange(data_point.shape[1]) * 1.0, 284 | 'label': label, 285 | } 286 | 287 | return s 288 | 289 | 290 | class SWATSegLoader(Dataset): 291 | """ 292 | A PyTorch Dataset class for loading and preprocessing the SWaT dataset for anomaly detection. 293 | The dataset is segmented based on window size and step (for sliding). 294 | 295 | Parameters: 296 | - win_size: The size of the window to segment the dataset. 297 | - step: The step size for sliding window. 298 | - flag: Indicates the dataset split to use ('train', 'val', 'test'). 299 | 300 | Attributes: 301 | - train: Training data after scaling. 302 | - val: Validation data after scaling. 303 | - test: Test data after scaling. 304 | - test_labels: Labels for the test data. 
305 | """ 306 | def __init__(self, win_size, step=100, flag="train"): 307 | self.flag = flag 308 | self.step = step 309 | self.win_size = win_size 310 | self.root_path = root_path+'SWaT/' 311 | self.scaler = StandardScaler() 312 | 313 | train_data = pd.read_csv(os.path.join(self.root_path, 'swat_train2.csv')) 314 | test_data = pd.read_csv(os.path.join(self.root_path, 'swat2.csv')) 315 | labels = test_data.values[:, -1:] 316 | train_data = train_data.values[:, :-1] 317 | test_data = test_data.values[:, :-1] 318 | 319 | self.scaler.fit(train_data) 320 | train_data = self.scaler.transform(train_data) 321 | test_data = self.scaler.transform(test_data) 322 | self.train = train_data 323 | self.test = test_data 324 | data_len = len(self.train) 325 | self.val = self.train[(int)(data_len * 0.8):] 326 | self.test_labels = labels 327 | print("test:", self.test.shape) 328 | print("train:", self.train.shape) 329 | 330 | def __len__(self): 331 | """ 332 | Number of images in the object dataset. 333 | """ 334 | if self.flag == "train": 335 | return (self.train.shape[0] - self.win_size) // self.step + 1 336 | elif (self.flag == 'val'): 337 | return (self.val.shape[0] - self.win_size) // self.step + 1 338 | elif (self.flag == 'test'): 339 | return (self.test.shape[0] - self.win_size) // self.step + 1 340 | else: 341 | return (self.test.shape[0] - self.win_size) // self.win_size + 1 342 | 343 | def __getitem__(self, index): 344 | index = index * self.step 345 | if self.flag == "train": 346 | data_point, label = np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 347 | elif (self.flag == 'val'): 348 | data_point, label = np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size]) 349 | elif (self.flag == 'test'): 350 | data_point, label = np.float32(self.test[index:index + self.win_size]), np.float32( 351 | self.test_labels[index:index + self.win_size]) 352 | else: 353 | data_point, label = np.float32(self.test[ 354 | index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32( 355 | self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]) 356 | 357 | s = { 358 | 'observed_data': data_point, 359 | 'observed_mask': np.ones_like(data_point), 360 | 'gt_mask': np.ones_like(data_point), 361 | 'timepoints': np.arange(self.win_size) * 1.0, 362 | 'feature_id': np.arange(data_point.shape[1]) * 1.0, 363 | 'label': label, 364 | } 365 | 366 | return s 367 | 368 | 369 | def anomaly_detection_dataloader(dataset_name, win_size=100, batch_size=128): 370 | """ 371 | Creates data loaders for training, validation, and testing datasets for anomaly detection. 372 | 373 | Parameters: 374 | - dataset_name: The name of the dataset to use ('SMAP', 'PSM', 'MSL', 'SMD', 'SWAT'). 375 | - win_size: The window size for segmenting the dataset. 376 | - batch_size: The size of the batch for data loading. 377 | 378 | Returns: 379 | - A tuple of DataLoader objects for the training, validation, and testing datasets. 
380 | """ 381 | 382 | if dataset_name == "SMAP": 383 | Dataset = SMAPSegLoader 384 | elif dataset_name == "PSM": 385 | Dataset = PSMSegLoader 386 | elif dataset_name == "MSL": 387 | Dataset = MSLSegLoader 388 | elif dataset_name == "SMD": 389 | Dataset = SMDSegLoader 390 | elif dataset_name == "SWAT": 391 | Dataset = SWATSegLoader 392 | 393 | train_dataset = Dataset(win_size=win_size, flag='train') 394 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False) 395 | 396 | valid_dataset = Dataset(win_size=win_size, flag='val') 397 | valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=False) 398 | 399 | test_dataset = Dataset(win_size=win_size, flag='test') 400 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False) 401 | print(len(train_dataset)) 402 | 403 | return train_dataloader, valid_dataloader, test_dataloader -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.optim import Adam 4 | from tqdm import tqdm 5 | import pickle 6 | import os 7 | import random 8 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score 9 | import torch.nn.functional as F 10 | from sklearn import metrics 11 | import time 12 | 13 | 14 | 15 | from utils.metrics import calc_quantile_CRPS_sum, calc_quantile_CRPS, save_roc_curve 16 | 17 | def set_seed(seed: int): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | torch.backends.cudnn.determinstic = True 25 | 26 | def gsutil_cp(src_path: str, dst_path: str): 27 | exec_result = os.system(f"gsutil cp -R {src_path} {dst_path}") 28 | if exec_result != 0: 29 | error_msg = f"gsutil_cp: Failed to copy file from {src_path} to {dst_path}" 30 | raise OSError(error_msg) 31 | else: 32 | print(f"gsutil_cp: copied file from {src_path} to {dst_path}") 33 | return exec_result, src_path, dst_path 34 | 35 | 36 | # Function to wait for a file to exist 37 | def wait_for_file(file_path, timeout=60): 38 | start_time = time.time() 39 | while not os.path.exists(file_path): 40 | time.sleep(1) # Wait for 1 second before checking again 41 | if time.time() - start_time > timeout: 42 | raise TimeoutError(f"File {file_path} not found after {timeout} seconds.") 43 | print(f"File {file_path} found. 
Continuing script...") 44 | 45 | def adjustment(gt, pred): 46 | anomaly_state = False 47 | for i in range(len(gt)): 48 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 49 | anomaly_state = True 50 | for j in range(i, 0, -1): 51 | if gt[j] == 0: 52 | break 53 | else: 54 | if pred[j] == 0: 55 | pred[j] = 1 56 | for j in range(i, len(gt)): 57 | if gt[j] == 0: 58 | break 59 | else: 60 | if pred[j] == 0: 61 | pred[j] = 1 62 | elif gt[i] == 0: 63 | anomaly_state = False 64 | if anomaly_state: 65 | pred[i] = 1 66 | return gt, pred 67 | 68 | def train( 69 | model, 70 | config, 71 | train_loader, 72 | valid_loader=None, 73 | test_loader=None, 74 | valid_epoch_interval=5, 75 | eval_epoch_interval=500, 76 | foldername="", 77 | mode = 'pretraining', 78 | scaler=0, 79 | mean_scaler=1, 80 | nsample=100, 81 | save_samples = False, 82 | physionet_classification=False, 83 | normalize_for_ad=False 84 | ): 85 | optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-6) 86 | if foldername != "": 87 | output_path = foldername + "/model.pth" 88 | loss_path = foldername + "/losses.txt" 89 | p1 = int(0.75 * config["epochs"]) 90 | p2 = int(0.9 * config["epochs"]) 91 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 92 | optimizer, milestones=[p1, p2], gamma=0.1 93 | ) 94 | 95 | best_valid_loss = 1e10 96 | for epoch_no in range(config["epochs"]): 97 | avg_loss = 0 98 | model.train() 99 | with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it: 100 | for batch_no, train_batch in enumerate(it, start=1): 101 | optimizer.zero_grad() 102 | 103 | loss = model(train_batch,task=mode, normalize_for_ad=normalize_for_ad) 104 | loss.backward() 105 | avg_loss += loss.item() 106 | optimizer.step() 107 | it.set_postfix( 108 | ordered_dict={ 109 | "avg_epoch_loss": avg_loss / batch_no, 110 | "epoch": epoch_no, 111 | }, 112 | refresh=False, 113 | ) 114 | if foldername != "": 115 | ## Save Losses in txt File 116 | with open(loss_path, "a") as file: 117 | file.write('avg_epoch_loss: '+ str(avg_loss / batch_no) + ", epoch= "+ str(epoch_no) + "\n") 118 | lr_scheduler.step() 119 | if valid_loader is not None and (epoch_no + 1) % valid_epoch_interval == 0: 120 | model.eval() 121 | avg_loss_valid = 0 122 | with torch.no_grad(): 123 | with tqdm(valid_loader, mininterval=5.0, maxinterval=50.0) as it: 124 | for batch_no, valid_batch in enumerate(it, start=1): 125 | loss = model(valid_batch, is_train=0, normalize_for_ad=normalize_for_ad) 126 | avg_loss_valid += loss.item() 127 | it.set_postfix( 128 | ordered_dict={ 129 | "valid_avg_epoch_loss": avg_loss_valid / batch_no, 130 | "epoch": epoch_no, 131 | }, 132 | refresh=False, 133 | ) 134 | if best_valid_loss > avg_loss_valid: 135 | best_valid_loss = avg_loss_valid 136 | with open(loss_path, "a") as file: 137 | file.write('best loss is updated to: '+ str(avg_loss_valid / batch_no) + "at epoch= "+ str(epoch_no) + "\n") 138 | print( 139 | "\n best loss is updated to ", 140 | avg_loss_valid / batch_no, 141 | "at", 142 | epoch_no, 143 | ) 144 | 145 | if mode == 'pretraining' and test_loader is not None and (epoch_no + 1) % eval_epoch_interval == 0: 146 | current_checkpoint = (epoch_no + 1) // eval_epoch_interval 147 | previous_checkpoint_path = foldername + "checkpoint_"+str(current_checkpoint-1)+"/model.pth" 148 | checkpoint_folder = foldername + "checkpoint_"+str(current_checkpoint) 149 | os.makedirs(checkpoint_folder, exist_ok=True) 150 | torch.save(model.state_dict(), checkpoint_folder+"/model.pth") 151 | if os.path.exists(previous_checkpoint_path): 152 | 
os.remove(previous_checkpoint_path) 153 | print(f"Checkpoint '{previous_checkpoint_path}' has been deleted.") 154 | else: 155 | print(f"No checkpoint found at '{previous_checkpoint_path}'.") 156 | model.eval() 157 | 158 | evaluate(model, test_loader, nsample=nsample, scaler=scaler, mean_scaler=mean_scaler, foldername=checkpoint_folder, save_samples = save_samples, physionet_classification=physionet_classification, normalize_for_ad=normalize_for_ad) 159 | 160 | if foldername != "": 161 | torch.save(model.state_dict(), output_path) 162 | 163 | def finetune(model, 164 | config, 165 | train_loader, 166 | criterion, 167 | foldername="", 168 | task = 'classification', 169 | normalize_for_ad=False): 170 | optimizer = Adam(model.parameters(), lr=0.0001, weight_decay=1e-6) 171 | if foldername != "": 172 | output_path = foldername + "/model.pth" 173 | loss_path = foldername + "/losses.txt" 174 | ### Include loss in train_finetuning 175 | for epoch_no in range(config["epochs"]): 176 | avg_loss = 0 177 | model.train() 178 | with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it: 179 | for batch_no, train_batch in enumerate(it, start=1): 180 | optimizer.zero_grad() 181 | outputs, loss = model.forward_finetuning(batch=train_batch, criterion=criterion, task=task, normalize_for_ad=normalize_for_ad) 182 | loss.backward() 183 | avg_loss += loss.item() 184 | optimizer.step() 185 | it.set_postfix( 186 | ordered_dict={ 187 | "avg_epoch_loss": avg_loss / batch_no, 188 | "epoch": epoch_no, 189 | }, 190 | refresh=False, 191 | ) 192 | ## Save Losses in txt File 193 | with open(loss_path, "a") as file: 194 | file.write('avg_epoch_loss: '+ str(avg_loss / batch_no) + ", epoch= "+ str(epoch_no) + "\n") 195 | 196 | if foldername != "": 197 | torch.save(model.state_dict(), output_path) 198 | 199 | 200 | 201 | 202 | def evaluate(model, test_loader, nsample=100, scaler=1, mean_scaler=0, foldername="", save_samples = False, physionet_classification=False, set_type='Train', normalize_for_ad=False): 203 | 204 | loss_path = foldername + "/losses.txt" 205 | with torch.no_grad(): 206 | model.eval() 207 | mse_total = 0 208 | mae_total = 0 209 | evalpoints_total = 0 210 | 211 | all_target = [] 212 | all_observed_point = [] 213 | all_observed_time = [] 214 | all_evalpoint = [] 215 | all_generated_samples = [] 216 | all_labels = [] 217 | with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it: 218 | for batch_no, test_batch in enumerate(it, start=1): 219 | output = model.evaluate(batch=test_batch, n_samples=nsample, normalize_for_ad=normalize_for_ad) 220 | if physionet_classification: 221 | samples, c_target, eval_points, observed_points, observed_time, labels = output 222 | else: 223 | samples, c_target, eval_points, observed_points, observed_time = output 224 | samples = samples.permute(0, 1, 3, 2) # (B,nsample,L,K) 225 | c_target = c_target.permute(0, 2, 1) # (B,L,K) 226 | eval_points = eval_points.permute(0, 2, 1) 227 | observed_points = observed_points.permute(0, 2, 1) 228 | 229 | samples_median = samples.median(dim=1) 230 | all_target.append(c_target) 231 | all_evalpoint.append(eval_points) 232 | all_observed_point.append(observed_points) 233 | all_observed_time.append(observed_time) 234 | all_generated_samples.append(samples) 235 | if physionet_classification: 236 | all_labels.extend(labels.tolist()) 237 | 238 | mse_current = ( 239 | ((samples_median.values - c_target) * eval_points) ** 2 240 | ) * (scaler ** 2) 241 | mae_current = ( 242 | torch.abs((samples_median.values - c_target) * eval_points) 243 | ) 
* scaler 244 | 245 | mse_total += mse_current.sum().item() 246 | mae_total += mae_current.sum().item() 247 | evalpoints_total += eval_points.sum().item() 248 | 249 | it.set_postfix( 250 | ordered_dict={ 251 | "rmse_total": np.sqrt(mse_total / evalpoints_total), 252 | "mae_total": mae_total / evalpoints_total, 253 | "batch_no": batch_no, 254 | }, 255 | refresh=True, 256 | ) 257 | all_target = torch.cat(all_target, dim=0) 258 | all_evalpoint = torch.cat(all_evalpoint, dim=0) 259 | all_observed_point = torch.cat(all_observed_point, dim=0) 260 | all_observed_time = torch.cat(all_observed_time, dim=0) 261 | all_generated_samples = torch.cat(all_generated_samples, dim=0) 262 | 263 | CRPS = calc_quantile_CRPS( 264 | all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler 265 | ) 266 | CRPS_sum = calc_quantile_CRPS_sum( 267 | all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler 268 | ) 269 | print("RMSE:", np.sqrt(mse_total / evalpoints_total)) 270 | print("MAE:", mae_total / evalpoints_total) 271 | print("CRPS:", CRPS) 272 | print("CRPS-sum:", CRPS_sum) 273 | print("MSE:", mse_total/evalpoints_total) 274 | 275 | with open(loss_path, "a") as file: 276 | file.write("RMSE:"+ str(np.sqrt(mse_total / evalpoints_total)) + "\n") 277 | file.write("MAE:"+ str(mae_total / evalpoints_total) + "\n") 278 | file.write("CRPS:"+ str(CRPS) + "\n") 279 | file.write("CRPS-sum:"+ str(CRPS_sum) + "\n") 280 | file.write("MSE:"+ str(mse_total/evalpoints_total) + "\n") 281 | 282 | print(len(all_labels)) 283 | if save_samples and physionet_classification: 284 | with open( 285 | foldername + f"/generated_outputs_{set_type}_nsample" + str(nsample) + ".pk", "wb" 286 | ) as f: 287 | 288 | 289 | pickle.dump( 290 | [ 291 | all_generated_samples, 292 | all_target, 293 | all_evalpoint, 294 | all_observed_point, 295 | all_observed_time, 296 | all_labels, 297 | scaler, 298 | mean_scaler, 299 | ], 300 | f, 301 | ) 302 | 303 | elif save_samples: 304 | with open( 305 | foldername + "/generated_outputs_nsample" + str(nsample) + ".pk", "wb" 306 | ) as f: 307 | 308 | 309 | pickle.dump( 310 | [ 311 | all_generated_samples, 312 | all_target, 313 | all_evalpoint, 314 | all_observed_point, 315 | all_observed_time, 316 | scaler, 317 | mean_scaler, 318 | ], 319 | f, 320 | ) 321 | with open( 322 | foldername + "/result_nsample" + str(nsample) + ".pk", "wb" 323 | ) as f: 324 | pickle.dump( 325 | [ 326 | np.sqrt(mse_total / evalpoints_total), 327 | mae_total / evalpoints_total, 328 | CRPS, 329 | CRPS_sum, 330 | mse_total / evalpoints_total, 331 | ], 332 | f, 333 | ) 334 | 335 | def evaluate_finetuning(model, train_loader, test_loader, foldername="", anomaly_ratio = 1, save_embeddings = False, task='classification', normalize_for_ad=False): 336 | attens_energy = [] 337 | train_energies = [] 338 | test_labels = [] 339 | all_correct = 0 340 | all_total = 0 341 | all_outputs = [] 342 | all_classes = [] 343 | with torch.no_grad(): 344 | model.eval() 345 | 346 | with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it: 347 | for batch_no, test_batch in enumerate(it, start=1): 348 | outputs, result = model.evaluate_finetuned_model(batch=test_batch, task=task, normalize_for_ad=normalize_for_ad) 349 | if task == 'classification': 350 | all_correct+=result[0] 351 | all_total+=result[1] 352 | all_outputs.append(outputs[0]) 353 | all_classes.append(outputs[1]) 354 | elif task == 'anomaly_detection': 355 | attens_energy.append(result) 356 | test_labels.append(test_batch["label"]) 357 | 358 | if task == 
'anomaly_detection': 359 | with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it: 360 | for batch_no, batch in enumerate(it, start=1): 361 | # reconstruction 362 | outputs, result = model.evaluate_finetuned_model(batch, task=task, normalize_for_ad=normalize_for_ad) 363 | train_energies.append(result) 364 | #print('Output shape', outputs.shape) 365 | #print('Score shape', score.shape) 366 | train_energies = np.concatenate(train_energies, axis=0).reshape(-1) 367 | train_energy = np.array(train_energies) 368 | 369 | 370 | # (2) find the threshold 371 | attens_energy = np.concatenate(attens_energy, axis=0).reshape(-1) 372 | test_energy = np.array(attens_energy) 373 | combined_energy = np.concatenate([train_energy, test_energy], axis=0) 374 | threshold = np.percentile(combined_energy, 100 - anomaly_ratio) 375 | #threshold = 8 376 | print("Threshold :", threshold) 377 | print("attens_energy :", attens_energy.shape) 378 | # (3) evaluation on the test set 379 | pred = (test_energy > threshold).astype(int) 380 | test_labels = np.concatenate(test_labels, axis=0).reshape(-1) 381 | test_labels = np.array(test_labels) 382 | gt = test_labels.astype(int) 383 | 384 | print("pred: ", pred.shape) 385 | print("gt: ", gt.shape) 386 | 387 | # (4) detection adjustment 388 | gt, pred = adjustment(gt, pred) 389 | 390 | pred = np.array(pred) 391 | gt = np.array(gt) 392 | print("pred: ", pred.shape) 393 | print("gt: ", gt.shape) 394 | 395 | accuracy = accuracy_score(gt, pred) 396 | precision, recall, f_score, support = precision_recall_fscore_support(gt, pred, average='binary') 397 | print("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format( 398 | accuracy, precision, 399 | recall, f_score)) 400 | with open(foldername+'results.txt', "a") as file: 401 | file.write("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format( 402 | accuracy, precision, 403 | recall, f_score)) 404 | 405 | 406 | elif task == 'classification': 407 | probabilities = F.softmax(torch.cat(all_outputs, dim=0), dim=1) 408 | fpr, tpr, thresholds = metrics.roc_curve(torch.cat(all_classes, dim=0).cpu().numpy(), probabilities[:, 1].cpu().numpy()) 409 | auc = roc_auc_score(np.array(torch.cat(all_classes, dim=0).cpu().numpy()), np.array(probabilities[:, 1].cpu().numpy())) 410 | 411 | save_roc_curve(fpr, tpr, foldername) 412 | print('AUC: ', auc) 413 | print('Accuracy: ', all_correct/all_total) 414 | with open(foldername+'results.txt', "a") as file: 415 | file.write("AUC:"+ str(auc) + "\n") 416 | file.write("Accuracy:"+ str(all_correct/all_total) + "\n") 417 | 418 | 419 | 420 | 421 | 422 | 423 | -------------------------------------------------------------------------------- /src/tsde/main_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | from base.denoisingNetwork import diff_Block 8 | from base.mtsEmbedding import embedding_MTS 9 | 10 | from utils.masking_strategies import get_mask_probabilistic_layering, get_mask_equal_p_sample, imputation_mask_batch, pattern_mask_batch, interpolation_mask_batch 11 | 12 | 13 | class TSDE_base(nn.Module): 14 | """ 15 | Base class for TSDE model. 16 | 17 | Attributes: 18 | - device: The device on which the model will run (CPU or CUDA). 19 | - target_dim: number of features in the MTS. 20 | - sample_feat: Whether to sample subset of features during training. 
21 | - mix_masking_strategy: Strategy for mixing masks during pretraining. 22 | - time_strategy: Strategy for embedding time points. 23 | - emb_time_dim: Dimension of time embeddings. 24 | - emb_cat_feature_dim: Dimension of categorical feature embeddings. 25 | - mts_emb_dim: Dimension of the MTS embeddings. 26 | - embed_layer: Embedding layer for feature embeddings. 27 | - diffmodel: Model block for diffusion. 28 | - embdmodel: Model for embedding MTS. 29 | - mlp: Multi-layer perceptron for classification tasks. 30 | - conv: Convolutional layer for anomaly detection. 31 | 32 | Methods: 33 | - time_embedding: Generates sinusoidal embeddings for time points. 34 | - get_mts_emb: Generates embeddings for MTS. 35 | - calc_loss: Calculates training loss for a given batch of data. 36 | - calc_loss_valid: Calculates validation loss for a given batch of data. 37 | - impute: Imputes missing values in the time series. 38 | - forward: Forward pass for pretraining, and fine-tuning for imputation, interpolation, and forecasting. 39 | - forward_finetuning: Forward pass for fine-tuning on specific tasks (classification or anomaly detection). 40 | - evaluate_finetuned_model: Evaluates the fine-tuned model for classification and anomaly detection. 41 | - evaluate: Evaluates the model on imputation, interpolation and forecasting. 42 | """ 43 | def __init__(self, target_dim, config, device, sample_feat): 44 | super().__init__() 45 | self.device = device 46 | self.target_dim = target_dim 47 | self.sample_feat=sample_feat 48 | 49 | self.mix_masking_strategy = config["model"]["mix_masking_strategy"] 50 | self.time_strategy = config["model"]["time_strategy"] 51 | 52 | self.emb_time_dim = config["embedding"]["timeemb"] 53 | self.emb_cat_feature_dim = config["embedding"]["featureemb"] 54 | 55 | self.mts_emb_dim = 1+2*config["embedding"]["channels"] 56 | 57 | self.embed_layer = nn.Embedding( 58 | num_embeddings=self.target_dim, embedding_dim=self.emb_cat_feature_dim 59 | ) 60 | 61 | config_diff = config["diffusion"] 62 | config_diff["mts_emb_dim"] = self.mts_emb_dim 63 | config_emb = config["embedding"] 64 | 65 | 66 | 67 | self.diffmodel = diff_Block(config_diff) 68 | self.embdmodel = embedding_MTS(config_emb) 69 | 70 | # parameters for diffusion models 71 | self.num_steps = config_diff["num_steps"] 72 | if config_diff["schedule"] == "quad": 73 | self.beta = np.linspace( 74 | config_diff["beta_start"] ** 0.5, config_diff["beta_end"] ** 0.5, self.num_steps 75 | ) ** 2 76 | elif config_diff["schedule"] == "linear": 77 | self.beta = np.linspace( 78 | config_diff["beta_start"], config_diff["beta_end"], self.num_steps 79 | ) 80 | 81 | self.alpha_hat = 1 - self.beta 82 | self.alpha = np.cumprod(self.alpha_hat) 83 | self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1) 84 | 85 | L = config_emb["num_timestamps"] 86 | K = config_emb["num_feat"] 87 | 88 | # Number of classes for classification experiments 89 | num_classes = config_emb["classes"] 90 | 91 | ## Classifier head 92 | self.mlp = nn.Sequential( 93 | nn.Linear(L*K*self.mts_emb_dim, 256), # Adjust as necessary 94 | nn.SiLU(), 95 | nn.Dropout(0.5), 96 | nn.Linear(256, 256), 97 | nn.SiLU(), 98 | nn.Dropout(0.5), 99 | nn.Linear(256, num_classes), 100 | ) 101 | 102 | ## projection to reconstruct MTS for Anomaly Detection 103 | self.conv = nn.Linear((self.mts_emb_dim-1)*K, K, bias=True) 104 | 105 | 106 | def time_embedding(self, pos, d_model=128): 107 | pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device) 108 | 
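# Standard Transformer-style sinusoidal encoding: even channels receive sin(pos * div_term)
# and odd channels cos(pos * div_term), where div_term = 10000^(-2i/d_model).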
position = pos.unsqueeze(2) 109 | div_term = 1 / torch.pow( 110 | 10000.0, torch.arange(0, d_model, 2).to(self.device) / d_model 111 | ) 112 | pe[:, :, 0::2] = torch.sin(position * div_term) 113 | pe[:, :, 1::2] = torch.cos(position * div_term) 114 | return pe 115 | 116 | 117 | def get_mts_emb(self, observed_tp, cond_mask, x_co, feature_id): 118 | B, K, L = cond_mask.shape 119 | if self.time_strategy == "hawkes": 120 | 121 | time_embed = self.time_embedding(observed_tp, self.emb_time_dim) # (B,L,emb) 122 | elif self.time_strategy == "categorical embeddings": 123 | time_embed = self.time_embed_layer( 124 | torch.arange(L).to(self.device) 125 | ) # (L,emb) 126 | time_embed = time_embed.unsqueeze(0).expand(B, -1, -1) ### (B,L,128) 127 | 128 | if feature_id is None: 129 | feature_embed = self.embed_layer( 130 | torch.arange(self.target_dim).to(self.device) 131 | ) 132 | feature_embed = feature_embed.unsqueeze(0).expand(B, -1, -1) 133 | else: 134 | feature_embed = self.embed_layer( 135 | feature_id 136 | ) # (K,emb) 137 | #print(x_co.shape, time_embed.shape, feature_embed.shape) 138 | cond_embed, xt, xf = self.embdmodel(x_co, time_embed, feature_embed) 139 | 140 | side_mask = cond_mask.unsqueeze(1) # (B,1,K,L) 141 | mts_emb = torch.cat([cond_embed, side_mask], dim=1) 142 | 143 | return mts_emb 144 | 145 | 146 | def calc_loss_valid( 147 | self, observed_data, cond_mask, observed_mask, mts_emb, is_train 148 | ): 149 | loss_sum = 0 150 | for t in range(self.num_steps): # calculate loss for all t 151 | loss = self.calc_loss( 152 | observed_data, cond_mask, observed_mask, mts_emb, is_train, set_t=t 153 | ) 154 | loss_sum += loss.detach() 155 | 156 | return loss_sum / self.num_steps 157 | 158 | def calc_loss( 159 | self, observed_data, cond_mask, observed_mask, mts_emb, is_train, set_t=-1 160 | ): 161 | B, K, L = observed_data.shape 162 | if is_train != 1: # for validation 163 | t = (torch.ones(B) * set_t).long().to(self.device) 164 | else: 165 | t = torch.randint(0, self.num_steps, [B]).to(self.device) 166 | current_alpha = self.alpha_torch[t] # (B,1,1) 167 | 168 | noise = torch.randn_like(observed_data) 169 | noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise 170 | 171 | total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask) 172 | 173 | predicted = self.diffmodel(total_input, mts_emb, t) # (B,K,L) 174 | 175 | target_mask = observed_mask - cond_mask 176 | residual = (noise - predicted) * target_mask 177 | num_eval = target_mask.sum() 178 | loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1) 179 | return loss 180 | 181 | def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask): 182 | 183 | noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1) 184 | 185 | return noisy_target 186 | 187 | def impute(self, observed_data, cond_mask, mts_emb, n_samples): 188 | B, K, L = observed_data.shape 189 | 190 | imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device) 191 | 192 | for i in range(n_samples): 193 | 194 | current_sample = torch.randn_like(observed_data) 195 | 196 | for t in range(self.num_steps - 1, -1, -1): 197 | cond_obs = (cond_mask * observed_data).unsqueeze(1) 198 | noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1) 199 | 200 | predicted= self.diffmodel(noisy_target, mts_emb, torch.tensor([t]).to(self.device)) 201 | coeff1 = 1 / self.alpha_hat[t] ** 0.5 202 | coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5 203 | current_sample = coeff1 * (current_sample - coeff2 * 
187 |     def impute(self, observed_data, cond_mask, mts_emb, n_samples):
188 |         B, K, L = observed_data.shape
189 |
190 |         imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device)
191 |
192 |         for i in range(n_samples):
193 |
194 |             current_sample = torch.randn_like(observed_data)
195 |
196 |             for t in range(self.num_steps - 1, -1, -1):
197 |                 cond_obs = (cond_mask * observed_data).unsqueeze(1)
198 |                 noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)
199 |
200 |                 predicted = self.diffmodel(noisy_target, mts_emb, torch.tensor([t]).to(self.device))
201 |                 coeff1 = 1 / self.alpha_hat[t] ** 0.5
202 |                 coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5
203 |                 current_sample = coeff1 * (current_sample - coeff2 * predicted)
204 |
205 |                 if t > 0:
206 |                     noise = torch.randn_like(current_sample)
207 |                     sigma = (
208 |                         (1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]
209 |                     ) ** 0.5
210 |                     current_sample += sigma * noise
211 |
212 |             imputed_samples[:, i] = current_sample.detach()
213 |         return imputed_samples
214 |
215 |     def forward(self, batch, is_train=1, task='pretraining', normalize_for_ad=False):
216 |         ## is_train = 1 for pretraining and fine-tuning (the task must then be specified); is_train = 0 for evaluation
217 |
218 |         (
219 |             observed_data,
220 |             observed_mask,
221 |             feature_id,
222 |             observed_tp,
223 |             gt_mask,
224 |             for_pattern_mask,
225 |             _,
226 |             _,
227 |             _
228 |         ) = self.process_data(batch, sample_feat=self.sample_feat, train=is_train)
229 |         if is_train == 0:
230 |             cond_mask = gt_mask
231 |         else:
232 |             if task == 'pretraining':
233 |                 if self.mix_masking_strategy == 'equal_p':
234 |                     cond_mask = get_mask_equal_p_sample(observed_mask)
235 |                 elif self.mix_masking_strategy == 'probabilistic_layering':
236 |                     cond_mask = get_mask_probabilistic_layering(observed_mask)
237 |                 else:
238 |                     print('Please choose one of the following masking strategies in the config: equal_p, probabilistic_layering')
239 |             elif task == 'Imputation':
240 |                 cond_mask = imputation_mask_batch(observed_mask)
241 |             elif task == 'Interpolation':
242 |                 cond_mask = interpolation_mask_batch(observed_mask)
243 |             elif task == 'Imputation with pattern':
244 |                 cond_mask = pattern_mask_batch(observed_mask)
245 |             elif task == 'Forecasting':
246 |                 cond_mask = gt_mask
247 |             else:
248 |                 print('Please choose a valid task so that the right masking is applied during fine-tuning')
249 |
250 |         if normalize_for_ad:
251 |             ## Normalization from non-stationary Transformer
252 |             means = observed_data.mean(2, keepdim=True)
253 |             observed_data = observed_data - means
254 |             stdev = torch.sqrt(torch.var(observed_data, dim=2, keepdim=True, unbiased=False) + 1e-5)
255 |             observed_data /= stdev
256 |
257 |
258 |         x_co = (cond_mask * observed_data).unsqueeze(1)
259 |         mts_emb = self.get_mts_emb(observed_tp, cond_mask, x_co, feature_id)
260 |
261 |         loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid
262 |
263 |         return loss_func(observed_data, cond_mask, observed_mask, mts_emb, is_train)
264 |
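The sampling loop in impute above follows the usual DDPM reverse update. Written out for a single step with scalar coefficients; the numbers are illustrative rather than taken from a trained schedule, and a zero tensor again stands in for the predicted noise:

import torch

alpha_hat_t, alpha_bar_t, alpha_bar_prev, beta_t = 0.98, 0.50, 0.51, 0.02   # illustrative
x_t = torch.randn(2, 3, 5)                  # current_sample at step t
predicted = torch.zeros_like(x_t)           # stand-in for diffmodel(noisy_target, mts_emb, t)
coeff1 = 1 / alpha_hat_t ** 0.5
coeff2 = (1 - alpha_hat_t) / (1 - alpha_bar_t) ** 0.5
x_prev = coeff1 * (x_t - coeff2 * predicted)
sigma = ((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t) ** 0.5      # noise is only added while t > 0
x_prev = x_prev + sigma * torch.randn_like(x_t)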
265 |     def forward_finetuning(self, batch, criterion, task='classification', normalize_for_ad=False):
266 |         ## task should be either 'classification' or 'anomaly_detection'
267 |
268 |         (
269 |             observed_data,
270 |             observed_mask,
271 |             feature_id,
272 |             observed_tp,
273 |             gt_mask,
274 |             _,
275 |             _,
276 |             _,
277 |             classes
278 |         ) = self.process_data(batch, sample_feat=self.sample_feat, train=False)
279 |
280 |         if normalize_for_ad:
281 |             ## Normalization from non-stationary Transformer
282 |             original_observed_data = observed_data.clone()
283 |             means = observed_data.mean(2, keepdim=True)
284 |             observed_data = observed_data - means
285 |             stdev = torch.sqrt(torch.var(observed_data, dim=2, keepdim=True, unbiased=False) + 1e-5)
286 |             observed_data /= stdev
287 |
288 |         x_co = (observed_mask * observed_data).unsqueeze(1)
289 |         mts_emb = self.get_mts_emb(observed_tp, observed_mask, x_co, feature_id)
290 |
291 |         if task == 'classification':
292 |             outputs = self.mlp(mts_emb.reshape(mts_emb.shape[0], -1))
293 |             classes = classes.to(self.device)
294 |             loss = criterion(outputs, classes)
295 |             return outputs, loss
296 |         elif task == 'anomaly_detection':
297 |             B, C, K, L = mts_emb.shape
298 |             #outputs = self.projection(mts_emb.permute(0,2,3,1)).squeeze(-1)
299 |             outputs = self.conv(mts_emb[:, :C-1, :, :].reshape(B, (C-1)*K, L).permute(0, 2, 1)).permute(0, 2, 1)
300 |             if normalize_for_ad:
301 |                 dec_out = outputs * \
302 |                     (stdev[:, :, 0].unsqueeze(2).repeat(
303 |                         1, 1, L))
304 |                 outputs = dec_out + \
305 |                     (means[:, :, 0].unsqueeze(2).repeat(
306 |                         1, 1, L))
307 |
308 |             loss = criterion(outputs, original_observed_data if normalize_for_ad else observed_data)  # compare against the un-normalized series
309 |             return outputs, loss
310 |
311 |     def evaluate_finetuned_model(self, batch, criterion=nn.MSELoss(reduction='none'), task='classification', normalize_for_ad=False):
312 |
313 |         (
314 |             observed_data,
315 |             observed_mask,
316 |             feature_id,
317 |             observed_tp,
318 |             gt_mask,
319 |             _,
320 |             _,
321 |             _,
322 |             classes
323 |         ) = self.process_data(batch, sample_feat=self.sample_feat, train=False)
324 |
325 |         with torch.no_grad():
326 |
327 |             if normalize_for_ad:
328 |                 ## Normalization from non-stationary Transformer
329 |                 original_observed_data = observed_data.clone()
330 |                 means = observed_data.mean(2, keepdim=True)
331 |                 observed_data = observed_data - means
332 |                 stdev = torch.sqrt(torch.var(observed_data, dim=2, keepdim=True, unbiased=False) + 1e-5)
333 |                 observed_data /= stdev
334 |
335 |             x_co = (observed_mask * observed_data).unsqueeze(1)
336 |             mts_emb = self.get_mts_emb(observed_tp, observed_mask, x_co, feature_id)
337 |
338 |             if task == 'classification':
339 |                 outputs = self.mlp(mts_emb.reshape(mts_emb.shape[0], -1))
340 |                 classes = classes.to(self.device)
341 |                 probabilities = F.softmax(outputs, dim=1)
342 |                 _, predicted_classes = torch.max(probabilities, 1)
343 |                 correct = (predicted_classes == classes).sum().item()
344 |                 total = classes.size(0)
345 |                 accuracy = correct / total
346 |                 #print(probabilities.cpu().numpy())
347 |                 #auc = roc_auc_score(classes.cpu().numpy(), probabilities[:, 1].cpu().numpy())
348 |                 #print('AUC', auc)
349 |                 return (outputs, classes), (correct, total)
350 |             elif task == 'anomaly_detection':
351 |                 B, C, K, L = mts_emb.shape
352 |                 outputs = self.conv(mts_emb[:, :C-1, :, :].reshape(B, (C-1)*K, L).permute(0, 2, 1)).permute(0, 2, 1)
353 |                 #outputs = self.projection(mts_emb.permute(0,2,3,1)).squeeze(-1)
354 |                 if normalize_for_ad:
355 |                     dec_out = outputs * \
356 |                         (stdev[:, :, 0].unsqueeze(2).repeat(
357 |                             1, 1, L))
358 |                     outputs = dec_out + \
359 |                         (means[:, :, 0].unsqueeze(2).repeat(
360 |                             1, 1, L))
361 |                 score = torch.mean(criterion(original_observed_data if normalize_for_ad else observed_data, outputs), dim=1)
362 |                 score = score.detach().cpu().numpy()
363 |                 return outputs, score
364 |
365 |
366 |
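The normalize_for_ad branches above apply the per-instance normalization from the non-stationary Transformer before embedding and undo it on the reconstructed series before scoring; the round trip in isolation, with toy data:

import torch

x = torch.randn(4, 8, 100)                 # toy (B, K, L) series
means = x.mean(2, keepdim=True)
stdev = torch.sqrt(torch.var(x, dim=2, keepdim=True, unbiased=False) + 1e-5)
x_norm = (x - means) / stdev               # what the embedding / reconstruction head sees
x_denorm = x_norm * stdev + means          # what is compared against the raw data
assert torch.allclose(x_denorm, x, atol=1e-4)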
367 |     def evaluate(self, batch, n_samples, normalize_for_ad=False):
368 |         (
369 |             observed_data,
370 |             observed_mask,
371 |             feature_id,
372 |             observed_tp,
373 |             gt_mask,
374 |             _,
375 |             cut_length,
376 |             _,
377 |             labels
378 |         ) = self.process_data(batch, sample_feat=self.sample_feat, train=False)
379 |
380 |         with torch.no_grad():
381 |             cond_mask = gt_mask
382 |             target_mask = observed_mask - cond_mask
383 |
384 |             if normalize_for_ad:
385 |                 ## Normalization from non-stationary Transformer
386 |                 original_observed_data = observed_data.clone()
387 |                 means = observed_data.mean(2, keepdim=True)
388 |                 observed_data = observed_data - means
389 |                 stdev = torch.sqrt(torch.var(observed_data, dim=2, keepdim=True, unbiased=False) + 1e-5)
390 |                 observed_data /= stdev
391 |
392 |             x_co = (cond_mask * observed_data).unsqueeze(1)
393 |             mts_emb = self.get_mts_emb(observed_tp, cond_mask, x_co, feature_id)
394 |
395 |             samples = self.impute(observed_data, cond_mask, mts_emb, n_samples)
396 |
397 |             for i in range(len(cut_length)):  # to avoid double evaluation
398 |                 target_mask[i, ..., 0 : cut_length[i].item()] = 0
399 |         if normalize_for_ad:
400 |             if labels is not None:
401 |                 return samples, original_observed_data, target_mask, observed_mask, observed_tp, labels
402 |             else:
403 |                 return samples, original_observed_data, target_mask, observed_mask, observed_tp
404 |         else:
405 |             if labels is not None:
406 |                 return samples, observed_data, target_mask, observed_mask, observed_tp, labels
407 |             else:
408 |                 return samples, observed_data, target_mask, observed_mask, observed_tp
409 |
410 |
411 |
412 | class TSDE_Forecasting(TSDE_base):
413 |     """
414 |     Specialized TSDE model for forecasting tasks.
415 |
416 |     This class extends the TSDE_base model to specifically handle forecasting by processing the input data appropriately.
417 |     """
418 |     def __init__(self, config, device, target_dim, sample_feat=False):
419 |         super(TSDE_Forecasting, self).__init__(target_dim, config, device, sample_feat)
420 |
421 |     def process_data(self, batch, sample_feat, train=True):
422 |         observed_data = batch["observed_data"].to(self.device).float()
423 |         observed_mask = batch["observed_mask"].to(self.device).float()
424 |         observed_tp = batch["timepoints"].to(self.device).float()
425 |         gt_mask = batch["gt_mask"].to(self.device).float()
426 |         feature_id = batch["feature_id"].to(self.device).long()
427 |         observed_data = observed_data.permute(0, 2, 1)
428 |         observed_mask = observed_mask.permute(0, 2, 1)
429 |         gt_mask = gt_mask.permute(0, 2, 1)
430 |         if train and sample_feat:
431 |             sampled_data = []
432 |             sampled_mask = []
433 |             sampled_feature_id = []
434 |             sampled_gt_mask = []
435 |             size = 128
436 |
437 |             for i in range(len(observed_data)):
438 |                 ind = np.arange(feature_id.shape[1])
439 |                 np.random.shuffle(ind)
440 |                 sampled_data.append(observed_data[i, ind[:size], :])
441 |                 sampled_mask.append(observed_mask[i, ind[:size], :])
442 |                 sampled_feature_id.append(feature_id[i, ind[:size]])
443 |                 sampled_gt_mask.append(gt_mask[i, ind[:size], :])
444 |             observed_data = torch.stack(sampled_data, 0)
445 |             observed_mask = torch.stack(sampled_mask, 0)
446 |             feature_id = torch.stack(sampled_feature_id, 0)
447 |             gt_mask = torch.stack(sampled_gt_mask, 0)
448 |
449 |         cut_length = torch.zeros(len(observed_data)).long().to(self.device)
450 |         for_pattern_mask = observed_mask
451 |
452 |         return (
453 |             observed_data,
454 |             observed_mask,
455 |             feature_id,
456 |             observed_tp,
457 |             gt_mask,
458 |             for_pattern_mask,
459 |             cut_length,
460 |             None,
461 |             None,
462 |         )
463 |
464 |
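Each dataset-specific subclass (TSDE_Forecasting above, and TSDE_PM25, TSDE_Physio and TSDE_AD below) only supplies process_data, returning the 9-tuple consumed by the base class: (observed_data, observed_mask, feature_id, observed_tp, gt_mask, for_pattern_mask, cut_length, a placeholder, labels), with (B, K, L)-shaped tensors on self.device. A minimal hypothetical subclass for a new dataset, not part of the released code, could look like:

class TSDE_MyDataset(TSDE_base):   # hypothetical example, not in the repository
    def __init__(self, config, device, target_dim, sample_feat=False):
        super().__init__(target_dim, config, device, sample_feat)

    def process_data(self, batch, train, sample_feat):
        observed_data = batch["observed_data"].to(self.device).float().permute(0, 2, 1)   # (B, K, L)
        observed_mask = batch["observed_mask"].to(self.device).float().permute(0, 2, 1)
        observed_tp = batch["timepoints"].to(self.device).float()
        gt_mask = batch["gt_mask"].to(self.device).float().permute(0, 2, 1)
        cut_length = torch.zeros(len(observed_data)).long().to(self.device)
        return (observed_data, observed_mask, None, observed_tp, gt_mask,
                observed_mask, cut_length, None, None)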
470 | """ 471 | def __init__(self, config, device, target_dim=36, sample_feat=False): 472 | super(TSDE_PM25, self).__init__(target_dim, config, device, sample_feat) 473 | 474 | def process_data(self, batch, train, sample_feat): 475 | observed_data = batch["observed_data"].to(self.device).float() 476 | observed_mask = batch["observed_mask"].to(self.device).float() 477 | observed_tp = batch["timepoints"].to(self.device).float() 478 | gt_mask = batch["gt_mask"].to(self.device).float() 479 | cut_length = batch["cut_length"].to(self.device).long() 480 | for_pattern_mask = batch["hist_mask"].to(self.device).float() 481 | 482 | observed_data = observed_data.permute(0, 2, 1) 483 | observed_mask = observed_mask.permute(0, 2, 1) 484 | gt_mask = gt_mask.permute(0, 2, 1) 485 | for_pattern_mask = for_pattern_mask.permute(0, 2, 1) 486 | 487 | return ( 488 | observed_data, 489 | observed_mask, 490 | None, 491 | observed_tp, 492 | gt_mask, 493 | for_pattern_mask, 494 | cut_length, 495 | None, 496 | None, 497 | ) 498 | 499 | 500 | class TSDE_Physio(TSDE_base): 501 | """ 502 | Specialized TSDE model for PhysioNet dataset. 503 | 504 | Adapts the TSDE_base model for tasks involving PhysioNet data, including imputation and interpolation. 505 | """ 506 | def __init__(self, config, device, target_dim=35, sample_feat=False): 507 | super(TSDE_Physio, self).__init__(target_dim, config, device, sample_feat) 508 | 509 | def process_data(self, batch, train, sample_feat): 510 | observed_data = batch["observed_data"].to(self.device).float() 511 | observed_mask = batch["observed_mask"].to(self.device).float() 512 | observed_tp = batch["timepoints"].to(self.device).float() 513 | gt_mask = batch["gt_mask"].to(self.device).float() 514 | labels = batch["labels"] 515 | observed_data = observed_data.permute(0, 2, 1) 516 | observed_mask = observed_mask.permute(0, 2, 1) 517 | gt_mask = gt_mask.permute(0, 2, 1) 518 | 519 | cut_length = torch.zeros(len(observed_data)).long().to(self.device) 520 | for_pattern_mask = observed_mask 521 | 522 | return ( 523 | observed_data, 524 | observed_mask, 525 | None, 526 | observed_tp, 527 | gt_mask, 528 | for_pattern_mask, 529 | cut_length, 530 | None, 531 | labels, 532 | 533 | ) 534 | 535 | 536 | class TSDE_AD(TSDE_base): 537 | """ 538 | Specialized TSDE model for anomaly detection datasets. 539 | 540 | Tailors the TSDE_base model for anomaly detection datasets, including MSL, SMD, PSM, SMAP and SWaT. 541 | """ 542 | def __init__(self, config, device, target_dim=55, sample_feat=False): 543 | super(TSDE_AD, self).__init__(target_dim, config, device, sample_feat) 544 | 545 | def process_data(self, batch, train, sample_feat): 546 | observed_data = batch["observed_data"].to(self.device).float() 547 | observed_mask = batch["observed_mask"].to(self.device).float() 548 | observed_tp = batch["timepoints"].to(self.device).float() 549 | gt_mask = batch["gt_mask"].to(self.device).float() 550 | label = batch["label"].to(self.device).float() 551 | observed_data = observed_data.permute(0, 2, 1) 552 | observed_mask = observed_mask.permute(0, 2, 1) 553 | gt_mask = gt_mask.permute(0, 2, 1) 554 | 555 | cut_length = torch.zeros(len(observed_data)).long().to(self.device) 556 | for_pattern_mask = observed_mask 557 | return ( 558 | observed_data, 559 | observed_mask, 560 | None, 561 | observed_tp, 562 | gt_mask, 563 | for_pattern_mask, 564 | cut_length, 565 | None, 566 | label, 567 | ) 568 | 569 | --------------------------------------------------------------------------------