├── __init__.py ├── .gitattributes ├── data └── swat │ ├── test.csv │ ├── train.csv │ └── list.txt ├── util ├── time.py ├── net_struct.py ├── preprocess.py └── data.py ├── .gitignore ├── README.md ├── evaluate.py ├── train.py ├── datasets └── TimeDataset.py ├── test.py ├── models ├── graph_layer.py └── FuSAGNet.py └── main.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.csv filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /data/swat/test.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b9a5cc1f1fa11f259dcdb39c80889124fbab3c761765570cfe9557b4311f30fb 3 | size 131205212 4 | -------------------------------------------------------------------------------- /data/swat/train.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3ca54efc0c4300a5efe61dbad907944f7095fc259257a9ae4c33d4f5f6cbc0d3 3 | size 145065906 4 | -------------------------------------------------------------------------------- /data/swat/list.txt: -------------------------------------------------------------------------------- 1 | FIT101 2 | LIT101 3 | MV101 4 | P101 5 | P102 6 | AIT201 7 | AIT202 8 | AIT203 9 | FIT201 10 | MV201 11 | P201 12 | P202 13 | P203 14 | P204 15 | P205 16 | P206 17 | DPIT301 18 | FIT301 19 | LIT301 20 | MV301 21 | MV302 22 | MV303 23 | MV304 24 | P301 25 | P302 26 | AIT401 27 | AIT402 28 | FIT401 29 | LIT401 30 | P401 31 | P402 32 | P403 33 | P404 34 | UV401 35 | AIT501 36 | AIT502 37 | AIT503 38 | AIT504 39 | FIT501 40 | FIT502 41 | FIT503 42 | FIT504 43 | P501 44 | P502 45 | PIT501 
46 | PIT502 47 | PIT503 48 | FIT601 49 | P601 50 | P602 51 | P603 52 | -------------------------------------------------------------------------------- /util/time.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | from datetime import datetime 4 | 5 | 6 | def asMinutes(s): 7 | m = math.floor(s / 60) 8 | s -= m * 60 9 | return "%dm %ds" % (m, s) 10 | 11 | 12 | def timeSincePlus(since, percent): 13 | now = time.time() 14 | s = now - since 15 | es = s / (percent) 16 | rs = es - s 17 | return "%s (- %s)" % (asMinutes(s), asMinutes(rs)) 18 | 19 | 20 | def timeSince(since): 21 | now = time.time() 22 | s = now - since 23 | m = math.floor(s / 60) 24 | s -= m * 60 25 | return "%dm %ds" % (m, s) 26 | 27 | 28 | def timestamp2str(sec, fmt, tz): 29 | return datetime.fromtimestamp(sec).astimezone(tz).strftime(fmt) 30 | -------------------------------------------------------------------------------- /util/net_struct.py: -------------------------------------------------------------------------------- 1 | def get_feature_map(dataset): 2 | feature_file = open(f"./data/{dataset}/list.txt", "r") 3 | feature_list = [] 4 | for ft in feature_file: 5 | feature_list.append(ft.strip()) 6 | 7 | return feature_list 8 | 9 | 10 | def get_fc_graph_struc(dataset): 11 | feature_file = open(f"./data/{dataset}/list.txt", "r") 12 | struc_map = {} 13 | feature_list = [] 14 | for ft in feature_file: 15 | feature_list.append(ft.strip()) 16 | 17 | for ft in feature_list: 18 | if ft not in struc_map: 19 | struc_map[ft] = [] 20 | 21 | for other_ft in feature_list: 22 | if other_ft is not ft: 23 | struc_map[ft].append(other_ft) 24 | 25 | return struc_map 26 | -------------------------------------------------------------------------------- /util/preprocess.py: -------------------------------------------------------------------------------- 1 | def construct_data(data, feature_map, labels=0): 2 | res = [] 3 | for feature in feature_map: 4 | 
if feature in data.columns: 5 | res.append(data.loc[:, feature].values.tolist()) 6 | else: 7 | print(feature, "not exist in data") 8 | 9 | sample_n = len(res[0]) 10 | if type(labels) == int: 11 | res.append([labels] * sample_n) 12 | elif len(labels) == sample_n: 13 | res.append(labels) 14 | 15 | return res 16 | 17 | 18 | def build_loc_net(struc, all_features, feature_map=[]): 19 | index_feature_map = feature_map 20 | edge_indexes = [[], []] 21 | for node_name, node_list in struc.items(): 22 | if node_name not in all_features: 23 | continue 24 | 25 | if node_name not in index_feature_map: 26 | index_feature_map.append(node_name) 27 | 28 | p_index = index_feature_map.index(node_name) 29 | for child in node_list: 30 | if child not in all_features: 31 | continue 32 | 33 | if child not in index_feature_map: 34 | print(f"Error: {child} not in index_feature_map") 35 | 36 | c_index = index_feature_map.index(child) 37 | edge_indexes[0].append(c_index) 38 | edge_indexes[1].append(p_index) 39 | 40 | return edge_indexes 41 | -------------------------------------------------------------------------------- /util/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from scipy.stats import iqr, rankdata 5 | from sklearn.metrics import f1_score 6 | 7 | 8 | def generate_gaussian_noise(x, mean=0.0, std=1.0): 9 | eps = torch.randn(x.size()) * std + mean 10 | return eps 11 | 12 | 13 | def eval_scores(scores, true_scores, th_steps, return_threshold=False): 14 | padding_list = [0] * (len(true_scores) - len(scores)) 15 | if len(padding_list) > 0: 16 | scores = padding_list + scores 17 | 18 | scores_sorted = rankdata(scores, method="ordinal") 19 | th_vals = np.array(range(th_steps)) * 1.0 / th_steps 20 | fmeas = [None] * th_steps 21 | thresholds = [None] * th_steps 22 | for i in range(th_steps): 23 | cur_pred = scores_sorted > th_vals[i] * len(scores) 24 | fmeas[i] = 
f1_score(true_scores, cur_pred) 25 | score_index = scores_sorted.tolist().index(int(th_vals[i] * len(scores) + 1)) 26 | thresholds[i] = scores[score_index] 27 | 28 | if return_threshold: 29 | return fmeas, thresholds 30 | 31 | return fmeas 32 | 33 | 34 | def get_err_median_and_iqr(predicted, groundtruth): 35 | np_arr = np.subtract(np.array(predicted), np.array(groundtruth)) 36 | err_median = np.median(np_arr) 37 | err_iqr = iqr(np_arr) 38 | return err_median, err_iqr 39 | 40 | 41 | def kl_divergence(p, q): 42 | p = F.softmax(p, dim=1) 43 | q = F.softmax(q, dim=1) 44 | s1 = torch.sum(p * torch.log(p / q)) 45 | s2 = torch.sum((1 - p) * torch.log((1 - p) / (1 - q))) 46 | return s1 + s2 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Sparse Latent Graph Representations for Anomaly Detection in Multivariate Time Series (KDD '22) 2 | 3 | This repository contains the official PyTorch implementation* of Fused Sparse Autoencoder and Graph Net (FuSAGNet), introduced in ["Learning Sparse Latent Graph Representations for Anomaly Detection in Multivariate Time Series" (KDD '22)](https://dl.acm.org/doi/abs/10.1145/3534678.3539117). 4 | 5 | \*Partly based on the implementation of [GDN](https://github.com/d-ailin/GDN), introduced in ["Graph Neural Network-Based Anomaly Detection in Multivariate Time Series" (AAAI '21)](https://ojs.aaai.org/index.php/AAAI/article/view/16523). 
6 | 
7 | ## Repository Organization
8 | 
9 |     ├── data
10 |     |   └── swat
11 |     |       ├── list.txt
12 |     |       ├── test.csv
13 |     |       └── train.csv
14 |     ├── datasets
15 |     |   └── TimeDataset.py
16 |     ├── models
17 |     |   ├── FuSAGNet.py
18 |     |   └── graph_layer.py
19 |     ├── util
20 |     |   ├── data.py
21 |     |   ├── net_struct.py
22 |     |   ├── preprocess.py
23 |     |   └── time.py
24 |     ├── .gitattributes
25 |     ├── .gitignore
26 |     ├── README.md
27 |     ├── __init__.py
28 |     ├── evaluate.py
29 |     ├── main.py
30 |     ├── test.py
31 |     └── train.py
32 | 
33 | ## Requirements
34 | 
35 | * Python >= 3.6
36 | * CUDA == 10.2
37 | * PyTorch == 1.5.0
38 | * PyTorch Geometric == 1.5.0
39 | 
40 | ## Datasets
41 | 
42 | This repository includes [SWaT](https://link.springer.com/chapter/10.1007/978-3-319-71368-7_8) as the default dataset (see the `data` directory). The [WADI](https://dl.acm.org/doi/abs/10.1145/3055366.3055375) dataset can be requested [here](https://itrust.sutd.edu.sg/itrust-labs_datasets/) and the [HAI](https://www.usenix.org/system/files/cset20-paper-shin.pdf) dataset can be downloaded [here](https://github.com/icsdataset/hai).
43 | 
44 | ## Run
45 | 
46 | You can run the code using the following command.
47 | 
48 | ```
49 | python main.py
50 | ```
51 | 
52 | ## Citation
53 | 
54 | If you find our work useful, please consider citing our paper.
55 | 
56 | ```
57 | @inproceedings{han2022learning,
58 |   title={Learning Sparse Latent Graph Representations for Anomaly Detection in Multivariate Time Series},
59 |   author={Han, Siho and Woo, Simon S},
60 |   booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
61 |   pages={2977--2986},
62 |   year={2022}
63 | }
64 | ```
65 | 
66 | ## References
67 | * Han, Siho, and Simon S. Woo. "Learning Sparse Latent Graph Representations for Anomaly Detection in Multivariate Time Series." Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining. 2022.
68 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import precision_score, recall_score, roc_auc_score 3 | 4 | from util.data import * 5 | 6 | 7 | def get_full_err_scores(test_result, val_result): 8 | np_val_result = np.array(val_result) 9 | np_test_result = np.array(test_result) 10 | all_normals = None 11 | all_scores = None 12 | feature_num = np_test_result.shape[2] 13 | for i in range(feature_num): 14 | val_re_list = np_val_result[:2, :, i] 15 | test_re_list = np_test_result[:2, :, i] 16 | normal_dist = get_err_scores(val_re_list) 17 | scores = get_err_scores(test_re_list) 18 | 19 | if all_scores is None: 20 | all_normals = normal_dist 21 | all_scores = scores 22 | else: 23 | all_normals = np.vstack((all_normals, normal_dist)) 24 | all_scores = np.vstack((all_scores, scores)) 25 | 26 | return all_scores, all_normals 27 | 28 | 29 | def get_err_scores(test_res): 30 | test_predict, test_gt = test_res 31 | n_err_mid, n_err_iqr = get_err_median_and_iqr(test_predict, test_gt) 32 | test_delta = np.subtract( 33 | np.array(test_predict).astype(np.float64), np.array(test_gt).astype(np.float64) 34 | ) 35 | if len(test_delta.shape) >= 2: 36 | test_delta = np.max(test_delta, axis=1) 37 | 38 | epsilon = 1e-2 39 | err_scores = (test_delta - n_err_mid) / (np.abs(n_err_iqr) + epsilon) 40 | err_scores = np.abs(err_scores) 41 | smoothed_err_scores = np.zeros(err_scores.shape) 42 | before_num = 3 43 | for i in range(before_num, len(err_scores)): 44 | smoothed_err_scores[i] = np.mean(err_scores[i - before_num : i + 1]) 45 | 46 | return smoothed_err_scores 47 | 48 | 49 | def get_best_performance_data(total_err_scores, gt_labels, topk=1): 50 | total_features = total_err_scores.shape[0] 51 | topk_indices = np.argpartition( 52 | total_err_scores, range(total_features - topk - 1, total_features), axis=0 53 | 
)[-topk:] 54 | total_topk_err_scores = np.sum( 55 | np.take_along_axis(total_err_scores, topk_indices, axis=0), axis=0 56 | ) 57 | final_topk_fmeas, thresholds = eval_scores( 58 | total_topk_err_scores, gt_labels, 400, return_threshold=True 59 | ) 60 | 61 | th_i = final_topk_fmeas.index(max(final_topk_fmeas)) 62 | threshold = thresholds[th_i] 63 | pred_labels = np.zeros(len(total_topk_err_scores)) 64 | pred_labels[total_topk_err_scores > threshold] = 1 65 | for i in range(len(pred_labels)): 66 | pred_labels[i] = int(pred_labels[i]) 67 | gt_labels[i] = int(gt_labels[i]) 68 | 69 | pre = precision_score(gt_labels, pred_labels) 70 | rec = recall_score(gt_labels, pred_labels) 71 | auc_score = roc_auc_score(gt_labels, total_topk_err_scores) 72 | return max(final_topk_fmeas), pre, rec, auc_score, threshold 73 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from test import * 7 | from util.data import * 8 | from util.time import * 9 | 10 | 11 | def train( 12 | model=None, 13 | save_path="", 14 | config={}, 15 | train_dataloader=None, 16 | val_dataloader=None, 17 | device=None, 18 | test_dataloader=None, 19 | test_dataset=None, 20 | dataset_name="swat", 21 | train_dataset=None, 22 | ): 23 | model.train() 24 | 25 | optimizer = torch.optim.Adam( 26 | model.parameters(), lr=config["lr"], weight_decay=config["decay"] 27 | ) 28 | 29 | alpha = config["alpha"] 30 | beta = config["beta"] 31 | epochs = config["epoch"] 32 | 33 | patience = epochs // 5 34 | epochs_improved = 0 35 | 36 | train_loss_list = [] 37 | min_loss = sys.float_info.max 38 | reduction = "mean" 39 | for epoch in range(epochs): 40 | model.train() 41 | 42 | total_loss = 0 43 | for ( 44 | x, 45 | y, 46 | _, 47 | edge_index, 48 | ) in train_dataloader: 49 | x = torch.add(x, generate_gaussian_noise(x)) 50 | x, y, 
edge_index = [item.float().to(device) for item in [x, y, edge_index]] 51 | 52 | optimizer.zero_grad() 53 | 54 | ( 55 | x_hat, 56 | x_recon, 57 | _, 58 | mu, 59 | log_var, 60 | _, 61 | _, 62 | _, 63 | _, 64 | rhos, 65 | rho_hat, 66 | ) = model(x, y, edge_index) 67 | 68 | x_hat = x_hat.float().to(device) 69 | x_recon = x_recon.float().to(device) 70 | if (mu is not None) and (log_var is not None): 71 | mu = mu.float().to(device) 72 | log_var = log_var.float().to(device) 73 | 74 | loss_frcst = torch.sqrt(F.mse_loss(x_hat, y, reduction=reduction)) 75 | loss_recon = F.mse_loss(x_recon, x, reduction=reduction) 76 | loss_recon += beta * kl_divergence(rhos, rho_hat) 77 | loss = alpha * loss_frcst + (1.0 - alpha) * loss_recon 78 | 79 | loss.backward() 80 | optimizer.step() 81 | 82 | total_loss += loss.item() 83 | train_loss_list.append(total_loss) 84 | 85 | train_loss_log = f"F: {alpha * loss_frcst:.4f} | R: {(1.0 - alpha) * loss_recon:.4f} | beta*KLD: {beta * kl_divergence(rhos, rho_hat)}" 86 | if val_dataloader is not None: 87 | val_loss, _ = test(model, val_dataloader, device, config=config) 88 | val_loss_log = f"V: {val_loss:.4f}" 89 | loss_log = f"[E {epoch + 1}/{epochs}] " + " | ".join( 90 | [val_loss_log, train_loss_log] 91 | ) 92 | print(loss_log, flush=True) 93 | if val_loss < min_loss: 94 | torch.save(model.state_dict(), save_path) 95 | min_loss = val_loss 96 | epochs_improved = 0 97 | else: 98 | epochs_improved += 1 99 | 100 | if epochs_improved >= patience: 101 | break 102 | 103 | return train_loss_list 104 | -------------------------------------------------------------------------------- /datasets/TimeDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class TimeDataset(Dataset): 6 | def __init__( 7 | self, 8 | raw_data, 9 | edge_index, 10 | mode="train", 11 | task="reconstruction", 12 | min_train=None, 13 | max_train=None, 14 | config=None, 15 | 
preprocess=None, 16 | ): 17 | self.raw_data = raw_data 18 | self.edge_index = edge_index 19 | self.mode = mode 20 | self.task = task 21 | self.config = config 22 | 23 | x_data = raw_data[:-1] 24 | labels = raw_data[-1] 25 | data = torch.tensor(x_data).double() 26 | if mode == "train": 27 | min_train, max_train = [], [] 28 | for i in range(data.size(0)): 29 | col_min, col_max = torch.min(data[i]), torch.max(data[i]) 30 | min_train.append(col_min) 31 | max_train.append(col_max) 32 | if (col_min == 0.0) and (col_max == 0.0): 33 | pass 34 | elif (col_min != 0.0) and (col_max == col_min): 35 | data[i] /= min_train[-1] 36 | else: 37 | data[i] -= min_train[-1] 38 | data[i] /= max_train[-1] - min_train[-1] 39 | 40 | self.min_train, self.max_train = min_train, max_train 41 | 42 | elif mode == "test": 43 | for i in range(data.size(0)): 44 | col_min, col_max = min_train[i], max_train[i] 45 | if (col_min == 0.0) and (col_max == 0.0): 46 | pass 47 | elif (col_min != 0.0) and (col_max == col_min): 48 | data[i] /= col_min 49 | else: 50 | data[i] -= col_min 51 | data[i] /= col_max - col_min 52 | 53 | labels = torch.tensor(labels).double() 54 | self.x, self.y, self.labels = self.process(data, labels) 55 | 56 | def get_train_min_max(self): 57 | return self.min_train, self.max_train 58 | 59 | def get_train_mean_std(self): 60 | return self.means, self.stds 61 | 62 | def process(self, data, labels): 63 | x_arr, labels_arr = [], [] 64 | y_arr = None if self.task == "reconstruction" else [] 65 | slide_win, slide_stride = [ 66 | self.config[k] for k in ("slide_win", "slide_stride") 67 | ] 68 | is_train = self.mode == "train" 69 | _, total_time_len = data.shape 70 | rang = ( 71 | range(slide_win, total_time_len, slide_stride) 72 | if is_train 73 | else range(slide_win, total_time_len) 74 | ) 75 | for i in rang: 76 | window = data[:, i - slide_win : i] 77 | if y_arr is not None: 78 | target = data[:, i] 79 | y_arr.append(target) 80 | 81 | x_arr.append(window) 82 | 
labels_arr.append(labels[i]) 83 | 84 | x = torch.stack(x_arr).contiguous() 85 | y = torch.stack(y_arr).contiguous() if y_arr is not None else None 86 | labels = torch.Tensor(labels_arr).contiguous() 87 | return x, y, labels 88 | 89 | def __len__(self): 90 | return len(self.x) 91 | 92 | def __getitem__(self, idx): 93 | window = self.x[idx].double() 94 | window_y = self.y[idx].double() if self.y is not None else None 95 | label = self.labels[idx].double() 96 | edge_index = self.edge_index.long() 97 | return window, window_y, label, edge_index 98 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from util.data import * 7 | from util.preprocess import * 8 | from util.time import * 9 | 10 | 11 | def test(model, dataloader, device, config={}): 12 | now = time.time() 13 | model.eval() 14 | 15 | test_len = len(dataloader) 16 | y_hat_list = [] 17 | y_list = [] 18 | y_label_list = [] 19 | x_hat_list = [] 20 | x_list = [] 21 | x_label_list = [] 22 | 23 | i = 0 24 | alpha = config["alpha"] 25 | beta = config["beta"] 26 | 27 | test_loss_list = [] 28 | total_loss = 0 29 | reduction = "mean" 30 | for ( 31 | x, 32 | y, 33 | labels, 34 | edge_index, 35 | ) in dataloader: 36 | x, y, edge_index = [item.float().to(device) for item in [x, y, edge_index]] 37 | with torch.no_grad(): 38 | ( 39 | x_hat, 40 | x_recon, 41 | _, 42 | mu, 43 | log_var, 44 | _, 45 | _, 46 | _, 47 | _, 48 | rhos, 49 | rho_hat, 50 | ) = model(x, y, edge_index) 51 | 52 | x_hat = x_hat.float().to(device) 53 | x_recon = x_recon.float().to(device) 54 | if (mu is not None) and (log_var is not None): 55 | mu = mu.float().to(device) 56 | log_var = log_var.float().to(device) 57 | 58 | loss_frcst = torch.sqrt(F.mse_loss(x_hat, y, reduction=reduction)) 59 | loss_recon = F.mse_loss(x_recon, x, reduction=reduction) 60 | 
loss_recon += beta * kl_divergence(rhos, rho_hat) 61 | 62 | predicted_y = x_hat 63 | y_labels = labels.unsqueeze(1).repeat(1, predicted_y.shape[1]) 64 | if len(y_hat_list) <= 0: 65 | y_hat_list = predicted_y 66 | y_list = y 67 | y_label_list = y_labels 68 | else: 69 | y_hat_list = torch.cat((y_hat_list, predicted_y), dim=0) 70 | y_list = torch.cat((y_list, y), dim=0) 71 | y_label_list = torch.cat((y_label_list, y_labels), dim=0) 72 | 73 | predicted_x = x_recon 74 | x_labels = labels.unsqueeze(1).repeat(1, predicted_x.shape[1]) 75 | if len(x_hat_list) <= 0: 76 | x_hat_list = predicted_x 77 | x_list = x 78 | x_label_list = x_labels 79 | else: 80 | x_hat_list = torch.cat((x_hat_list, predicted_x), dim=0) 81 | x_list = torch.cat((x_list, x), dim=0) 82 | x_label_list = torch.cat((x_label_list, x_labels), dim=0) 83 | 84 | loss = alpha * loss_frcst + (1.0 - alpha) * loss_recon 85 | 86 | total_loss += loss.item() 87 | test_loss_list.append(loss.item()) 88 | 89 | i += 1 90 | if i % 10000 == 1 and i > 1: 91 | print(timeSincePlus(now, i / test_len)) 92 | 93 | y_hat_list = y_hat_list.tolist() 94 | y_list = y_list.tolist() 95 | y_label_list = y_label_list.tolist() 96 | x_hat_list = x_hat_list.tolist() 97 | x_list = x_list.tolist() 98 | x_label_list = x_label_list.tolist() 99 | val_loss = sum(test_loss_list) / len(test_loss_list) 100 | return val_loss, { 101 | "forecasting": [y_hat_list, y_list, y_label_list], 102 | "reconstruction": [x_hat_list, x_list, x_label_list], 103 | } 104 | -------------------------------------------------------------------------------- /models/graph_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Linear, Parameter 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.nn.inits import glorot, zeros 6 | from torch_geometric.utils import add_self_loops, remove_self_loops, softmax 7 | 8 | 9 | class 
GraphLayer(MessagePassing): 10 | def __init__( 11 | self, 12 | in_channels, 13 | out_channels, 14 | heads=1, 15 | concat=True, 16 | negative_slope=0.2, 17 | dropout=0.0, 18 | bias=True, 19 | **kwargs 20 | ): 21 | super(GraphLayer, self).__init__(aggr="add", **kwargs) 22 | 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.heads = heads 26 | self.concat = concat 27 | self.negative_slope = negative_slope 28 | self.dropout = dropout 29 | self.__alpha__ = None 30 | 31 | self.lin = Linear(in_channels, heads * out_channels, bias=False) 32 | if bias and concat: 33 | self.bias = Parameter(torch.Tensor(heads * out_channels)) 34 | elif bias and (not concat): 35 | self.bias = Parameter(torch.Tensor(out_channels)) 36 | else: 37 | self.register_parameter("bias", None) 38 | 39 | self.att_i = Parameter(torch.Tensor(1, heads, out_channels)) 40 | self.att_j = Parameter(torch.Tensor(1, heads, out_channels)) 41 | self.att_em_i = Parameter(torch.Tensor(1, heads, out_channels)) 42 | self.att_em_j = Parameter(torch.Tensor(1, heads, out_channels)) 43 | 44 | self.reset_parameters() 45 | 46 | def reset_parameters(self): 47 | glorot(self.lin.weight) 48 | glorot(self.att_i) 49 | glorot(self.att_j) 50 | zeros(self.att_em_i) 51 | zeros(self.att_em_j) 52 | zeros(self.bias) 53 | 54 | def forward(self, x, edge_index, embedding, return_attention_weights=False): 55 | # Add self-loops to the adjacency matrix 56 | edge_index, _ = remove_self_loops(edge_index) 57 | edge_index, _ = add_self_loops(edge_index, num_nodes=x[1].size(self.node_dim)) 58 | 59 | # Linearly transform node feature matrix 60 | if torch.is_tensor(x): 61 | x = self.lin(x) 62 | x = (x, x) 63 | else: 64 | x = (self.lin(x[0]), self.lin(x[1])) 65 | 66 | # Start propagating messages 67 | out = self.propagate( 68 | edge_index=edge_index, 69 | x=x, 70 | embedding=embedding, 71 | edges=edge_index, 72 | return_attention_weights=return_attention_weights, 73 | ) 74 | if self.concat: 75 | out = out.view(-1, 
self.heads * self.out_channels) 76 | else: 77 | out = out.mean(dim=1) 78 | 79 | if self.bias is not None: 80 | out = out + self.bias 81 | 82 | if return_attention_weights: 83 | alpha, self.__alpha__ = self.__alpha__, None 84 | return out, (edge_index, alpha) 85 | else: 86 | return out 87 | 88 | def message( 89 | self, x_i, x_j, edge_index_i, size_i, embedding, edges, return_attention_weights 90 | ): 91 | x_i = x_i.view(-1, self.heads, self.out_channels) # Target node 92 | x_j = x_j.view(-1, self.heads, self.out_channels) # Source node 93 | 94 | if embedding is not None: 95 | embedding_i, embedding_j = ( 96 | embedding[edge_index_i], # edge_index_i = edges[1], i.e., parent nodes 97 | embedding[edges[0]], # edges[0], i.e., child nodes 98 | ) 99 | embedding_i = embedding_i.unsqueeze(1).repeat(1, self.heads, 1) 100 | embedding_j = embedding_j.unsqueeze(1).repeat(1, self.heads, 1) 101 | 102 | # Eq. 6 in Deng and Hooi (2021) 103 | key_i = torch.cat((x_i, embedding_i), dim=-1) 104 | key_j = torch.cat((x_j, embedding_j), dim=-1) 105 | 106 | cat_att_i = torch.cat((self.att_i, self.att_em_i), dim=-1) 107 | cat_att_j = torch.cat((self.att_j, self.att_em_j), dim=-1) 108 | 109 | # Eq. 7 in Deng and Hooi (2021) 110 | attended_key_i = torch.einsum("nhd,mhd->nhd", key_i, cat_att_i) 111 | attended_key_j = torch.einsum("nhd,mhd->nhd", key_j, cat_att_j) 112 | alpha = attended_key_i.sum(dim=-1) + attended_key_j.sum(dim=-1) 113 | alpha = alpha.view(-1, self.heads, 1) 114 | alpha = F.leaky_relu(alpha, self.negative_slope) 115 | 116 | # Eq. 
8 in Deng and Hooi (2021) 117 | alpha = softmax(src=alpha, index=edge_index_i, num_nodes=size_i) 118 | if self.dropout > 0: 119 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 120 | if return_attention_weights: 121 | self.__alpha__ = alpha 122 | 123 | return torch.einsum("nhc,nhd->nhd", alpha, x_j) 124 | 125 | def __repr__(self): 126 | return "{}({}, {}, heads={})".format( 127 | self.__class__.__name__, self.in_channels, self.out_channels, self.heads 128 | ) 129 | -------------------------------------------------------------------------------- /models/FuSAGNet.py: -------------------------------------------------------------------------------- 1 | from random import uniform 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from util.time import * 8 | from .graph_layer import GraphLayer 9 | 10 | 11 | def get_batch_edge_index(org_edge_index, batch_num, node_num): 12 | edge_index = org_edge_index.clone().detach() 13 | edge_num = org_edge_index.shape[1] 14 | batch_edge_index = edge_index.repeat(1, batch_num).contiguous() 15 | for i in range(batch_num): 16 | batch_edge_index[:, i * edge_num : (i + 1) * edge_num] += i * node_num 17 | 18 | return batch_edge_index.long() 19 | 20 | 21 | class OutLayer(nn.Module): 22 | def __init__(self, in_size, out_size, layer_num): 23 | super(OutLayer, self).__init__() 24 | modules = [] 25 | for i in range(layer_num): 26 | if i == layer_num - 1: 27 | modules.append(nn.Linear(in_size if layer_num == 1 else out_size, 1)) 28 | else: 29 | layer_in_num = in_size if i == 0 else out_size 30 | modules.append(nn.Linear(layer_in_num, out_size)) 31 | modules.append(nn.BatchNorm1d(out_size)) 32 | modules.append(nn.ReLU()) 33 | 34 | self.mlp = nn.ModuleList(modules) 35 | 36 | def forward(self, x): 37 | out = x 38 | for module in self.mlp: 39 | if isinstance(module, nn.BatchNorm1d): 40 | out = out.permute(0, 2, 1) 41 | out = module(out) 42 | out = out.permute(0, 2, 1) 43 | else: 44 | out = 
module(out) 45 | 46 | return out 47 | 48 | 49 | class SparseEncoder(nn.Module): 50 | def __init__( 51 | self, in_size, latent_size, num_layers, use_bn=False, use_dropout=False 52 | ): 53 | super().__init__() 54 | self.in_size = in_size 55 | self.latent_size = latent_size 56 | layer_sizes = [in_size] + [ 57 | in_size - (i + 1) * (in_size - latent_size) // num_layers 58 | for i in range(num_layers) 59 | ] 60 | self.encoder = nn.ModuleList( 61 | [ 62 | nn.Linear(layer_sizes[i], layer_sizes[i + 1]) 63 | for i in range(len(layer_sizes) - 1) 64 | ] 65 | ) 66 | 67 | self.use_bn = use_bn 68 | self.use_dropout = use_dropout 69 | if use_bn: 70 | self.bns = nn.ModuleList( 71 | [ 72 | nn.BatchNorm1d(self.encoder[i].out_features) 73 | for i in range(len(self.encoder)) 74 | ] 75 | ) 76 | 77 | def forward(self, x): 78 | B = x.size(0) 79 | if x.dim() > 2: 80 | x = x.view(B, -1) 81 | 82 | feature_maps = [] 83 | for i, module in enumerate(self.encoder): 84 | x = module(x) 85 | if self.use_bn: 86 | x = self.bns[i](x) 87 | 88 | x = torch.sigmoid(x) 89 | feature_maps.append(x) 90 | 91 | return x, feature_maps 92 | 93 | 94 | class SparseDecoder(nn.Module): 95 | def __init__( 96 | self, 97 | latent_size, 98 | out_size, 99 | n_features, 100 | num_layers, 101 | use_bn=False, 102 | use_dropout=False, 103 | ): 104 | super().__init__() 105 | self.latent_size = latent_size 106 | self.out_size = out_size 107 | layer_sizes = [out_size] + [ 108 | out_size - (i + 1) * (out_size - latent_size) // num_layers 109 | for i in range(num_layers) 110 | ] 111 | layer_sizes = layer_sizes[::-1] 112 | self.decoder = nn.ModuleList( 113 | [ 114 | nn.Linear(layer_sizes[i], layer_sizes[i + 1]) 115 | for i in range(len(layer_sizes) - 1) 116 | ] 117 | ) 118 | 119 | self.use_bn = use_bn 120 | self.use_dropout = use_dropout 121 | if use_bn: 122 | self.bns = nn.ModuleList( 123 | [ 124 | nn.BatchNorm1d(self.decoder[i].out_features) 125 | for i in range(len(self.decoder)) 126 | ] 127 | ) 128 | 129 | def forward(self, 
x):
        # NOTE(review): this excerpt opens mid-definition — the `def forward(self, x)`
        # header of SparseDecoder.forward lies above the visible chunk, and the exact
        # loop nesting below is reconstructed from a flattened dump; confirm the
        # sigmoid/append placement against the original file.
        feature_maps = []
        for i, module in enumerate(self.decoder):
            x = module(x)
            if self.use_bn:
                x = self.bns[i](x)

            # Sigmoid keeps each layer's activations in (0, 1); every activation
            # map is retained so the caller can apply per-layer penalties
            # (FuSAGNet.forward consumes these as `dec_feature_maps`).
            x = torch.sigmoid(x)
            feature_maps.append(x)

        return x, feature_maps


class GNNLayer(nn.Module):
    """One graph block: GraphLayer attention -> BatchNorm1d -> ReLU.

    `concat=False` makes GraphLayer average its attention heads, so the output
    width equals `out_channel` regardless of `heads`.
    """

    def __init__(self, in_channel, out_channel, heads=1):
        super(GNNLayer, self).__init__()
        self.gnn = GraphLayer(
            in_channels=in_channel, out_channels=out_channel, heads=heads, concat=False
        )
        self.bn = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index, embedding=None):
        # Attention weights are requested from GraphLayer but discarded here.
        out, (_, _) = self.gnn(x, edge_index, embedding, return_attention_weights=True)
        out = self.bn(out)
        return self.relu(out)


class FuSAGNet(nn.Module):
    """Fused sparse autoencoder + graph network for multivariate time series.

    Two branches share one latent code `z`:
      * reconstruction: SparseEncoder -> SparseDecoder (sparsity targets
        `rhos` / activations `rho_hat` are returned for an external penalty);
      * forecasting: a top-k cosine-similarity graph over learned sensor
        embeddings, processed by GNNLayer(s) and an OutLayer.

    Sensors are grouped by industrial process (`process_dict`: process name ->
    number of sensors); each group gets its own embedding table whose rows are
    refined by a bidirectional GRU.
    """

    def __init__(
        self,
        edge_index_sets,
        node_num,
        dim=16,
        out_layer_inter_dim=16,
        window_size=16,
        out_layer_num=1,
        topk=15,
        latent_size=16,
        n_layers=2,
        process_dict=None,
    ):
        super(FuSAGNet, self).__init__()
        self.edge_index_sets = edge_index_sets
        self.embed_dim = dim
        self.node_num = node_num

        # One nn.Embedding per process group, sized by that group's sensor
        # count; [sensor_i, sensor_f) tracks the running slice of sensors.
        sensor_f = 0
        embedding_modules = []
        for process in process_dict:
            sensor_i = sensor_f
            n_processes = process_dict.get(process)
            sensor_f += n_processes
            embedding_modules.append(nn.Embedding(sensor_f - sensor_i, self.embed_dim))

        self.embeddings = nn.ModuleList(embedding_modules)

        # Bidirectional GRU per process; embed_dim // 2 per direction keeps the
        # concatenated output width equal to embed_dim.
        n_rnn_layers = 3
        self.rnn_embedding_modules = nn.ModuleList(
            [
                nn.GRU(
                    self.embed_dim,
                    self.embed_dim // 2,
                    bidirectional=True,
                    num_layers=n_rnn_layers,
                    dropout=0.2,
                )
                for _ in range(len(process_dict))
            ]
        )
        self.bn_outlayer_in = nn.BatchNorm1d(self.embed_dim)

        edge_set_num = len(edge_index_sets)

        self.topk = topk
        self.learned_graph = None

        # NOTE(review): the `latent_size` constructor argument is ignored — the
        # per-node latent width is forced to `window_size` here. Possibly
        # intentional (keeps z the same shape as a window), but worth confirming.
        self.latent_size = window_size
        num_layers = n_layers
        self.encoder = SparseEncoder(
            in_size=node_num * window_size,
            latent_size=node_num * self.latent_size,
            num_layers=num_layers,
            use_bn=True,
            use_dropout=True,
        )
        self.decoder = SparseDecoder(
            latent_size=node_num * self.latent_size,
            out_size=node_num * window_size,
            n_features=node_num,
            num_layers=num_layers,
            use_bn=True,
            use_dropout=True,
        )
        self.gnn_layers = nn.ModuleList(
            [
                GNNLayer(in_channel=self.latent_size, out_channel=dim, heads=1)
                for _ in range(edge_set_num)
            ]
        )
        self.out_layer = OutLayer(
            in_size=dim * edge_set_num,
            out_size=out_layer_inter_dim,
            layer_num=out_layer_num,
        )

        self.dp = nn.Dropout(0.2)
        self.init_params(init_method="kaiming_uniform")

    def init_params(self, init_method):
        """Initialize every process-embedding table with the chosen scheme.

        Unrecognized `init_method` values silently leave the default init.
        """
        for embedding in self.embeddings:
            if init_method == "uniform":
                nn.init.uniform_(embedding.weight, a=0.0, b=1.0)
            elif init_method == "kaiming_uniform":
                nn.init.kaiming_uniform_(embedding.weight, a=0.0)
            elif init_method == "xavier_uniform":
                nn.init.xavier_uniform_(embedding.weight, gain=1.0)
            elif init_method == "normal":
                nn.init.normal_(embedding.weight, mean=0.0, std=1.0)
            elif init_method == "kaiming_normal":
                nn.init.kaiming_normal_(embedding.weight, a=0.0)
            elif init_method == "xavier_normal":
                nn.init.xavier_normal_(embedding.weight, gain=1.0)

    def sampling(self, mu, log_var):
        """VAE-style reparameterization trick (z = mu + eps * std).

        NOTE(review): not called anywhere in this excerpt; `mu`/`log_var` stay
        None in forward() — likely a leftover from a VAE variant.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, data, target, org_edge_index):
        """Run both branches on a batch of windows.

        `data` is (batch, node_num, window); `target` and `org_edge_index` are
        accepted but never read in this body (kept for caller compatibility).
        Returns an 11-tuple: forecast, reconstruction, detached latent, the
        always-None (mu, log_var) pair, encoder/decoder feature maps, the
        repeated embeddings, per-sensor process labels, and the sparsity
        targets/activations (rhos, rho_hat).
        """
        x = data.clone().detach()
        edge_index_sets = self.edge_index_sets
        device = data.device
        batch_num, node_num, _ = x.shape
        # Placeholders from an abandoned VAE path; always returned as None.
        mu, log_var = None, None

        # --- reconstruction branch -----------------------------------------
        z, enc_feature_maps = self.encoder(x)
        z_out = z.view(-1, self.node_num, self.latent_size).clone().detach()
        x_recon, dec_feature_maps = self.decoder(z)
        x_recon = x_recon.view(x.size())
        # Fresh random sparsity targets in [1e-5, 1e-2] are drawn on EVERY
        # forward call (python `random.uniform`, one per latent unit).
        rhos = (
            torch.FloatTensor([uniform(1e-5, 1e-2) for _ in range(z.size(1))])
            .unsqueeze(0)
            .to(device)
        )
        # Observed activation mass per latent unit, summed over the batch.
        rho_hat = torch.sum(z, dim=0, keepdim=True)
        # Drop the final (output) map of each stack; decoder maps are reversed
        # so they pair up with the encoder maps layer-for-layer.
        enc_fmaps, dec_fmaps = (
            enc_feature_maps[:-1],
            dec_feature_maps[:-1][::-1],
        )
        # Flatten to (batch * node_num, latent_size) node features for the GNN.
        z = z.view(-1, self.latent_size).contiguous()

        # --- forecasting branch --------------------------------------------
        gcn_outs = []
        for i, _ in enumerate(edge_index_sets):
            embedded_sensors = []
            y_process = []
            for j, embedding in enumerate(self.embeddings):
                # Embed all sensors of process j, refine through its GRU.
                sensors = torch.arange(embedding.num_embeddings).to(device)
                embedded = embedding(sensors)
                embedded = embedded.unsqueeze(0)
                embedded, _ = self.rnn_embedding_modules[j](embedded)
                embedded = embedded.squeeze()
                embedded_sensors.append(embedded)
                # Process label j, once per sensor, repeated for each batch item.
                y_process.extend(batch_num * [j for _ in range(embedded.size(0))])

            y_process = torch.tensor(y_process).to(device)
            all_embeddings = torch.cat(embedded_sensors)
            embeds = all_embeddings.view(node_num, -1)
            all_embeddings = all_embeddings.repeat(batch_num, 1)

            # Cosine similarity between all sensor embeddings; keep top-k
            # neighbours per node to form the learned (directed j -> i) graph.
            cos_ji_mat = torch.matmul(embeds, embeds.T)
            cos_ji_mat /= torch.matmul(
                embeds.norm(dim=-1).view(-1, 1), embeds.norm(dim=-1).view(1, -1)
            )
            topk_num = self.topk
            topk_indices_ji = torch.topk(cos_ji_mat, topk_num, dim=-1)[1]

            gated_i = (
                torch.arange(0, node_num)
                .unsqueeze(1)
                .repeat(1, topk_num)
                .flatten()
                .to(device)
                .unsqueeze(0)
            )
            gated_j = topk_indices_ji.flatten().unsqueeze(0)
            gated_edge_index = torch.cat((gated_j, gated_i), dim=0)
            # Replicate the edge list per batch item (presumably with a
            # node_num offset per copy — see get_batch_edge_index).
            batch_gated_edge_index = get_batch_edge_index(
                gated_edge_index, batch_num, node_num
            ).to(device)

            gcn_out = self.gnn_layers[i](
                x=z,
                edge_index=batch_gated_edge_index, embedding=all_embeddings
            )
            gcn_outs.append(gcn_out)

        x = torch.cat(gcn_outs, dim=1)
        x = x.view(batch_num, node_num, -1)

        # Gate GNN features by the (batch-shared) sensor embeddings, then
        # BatchNorm over the feature dim (hence the permutes) and project to
        # one forecast value per node.
        out = torch.mul(x, embeds)
        out = out.permute(0, 2, 1)
        out = F.relu(self.bn_outlayer_in(out))
        out = out.permute(0, 2, 1)
        out = self.dp(out)
        out = self.out_layer(out)
        out = out.view(-1, node_num)
        x_frcst = out
        return (
            x_frcst,
            x_recon,
            z_out,
            mu,
            log_var,
            enc_fmaps,
            dec_fmaps,
            all_embeddings,
            y_process,
            rhos,
            rho_hat,
        )

# ---------------------------------------------------------------------------
# /main.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import argparse
import os
import random
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from pytz import timezone
from torch.utils.data import DataLoader, Subset

from datasets.TimeDataset import TimeDataset
from evaluate import (
    get_best_performance_data,
    get_full_err_scores,
)
from models.FuSAGNet import FuSAGNet
from train import train
from test import test
from util.net_struct import get_fc_graph_struc, get_feature_map
from util.preprocess import build_loc_net, construct_data


class Main:
    """End-to-end experiment driver: data loading, training, evaluation."""

    def __init__(self, train_config, env_config, debug=False):
        self.train_config = train_config
        self.env_config = env_config
        self.datestr = None

        dataset = self.env_config["dataset"]
        train_orig = pd.read_csv(f"./data/{dataset}/train.csv", sep=",", index_col=0)
        test_orig = pd.read_csv(f"./data/{dataset}/test.csv", sep=",", index_col=0)
        if dataset in ["swat", "wadi"]:
            # Drop the first 2160 rows of swat/wadi training data
            # (presumably a plant start-up/stabilization period — confirm).
            train, test = (
                train_orig[2160:],
                test_orig,
            )
else: 41 | train, test = train_orig, test_orig 42 | 43 | if "attack" in train.columns: 44 | train = train.drop(columns=["attack"]) 45 | 46 | feature_map = get_feature_map(dataset) 47 | fc_struc = get_fc_graph_struc(dataset) 48 | 49 | self.device = torch.device( 50 | f'cuda:{train_config["gpu_id"]}' if torch.cuda.is_available() else "cpu" 51 | ) 52 | torch.cuda.set_device(self.device) 53 | 54 | fc_edge_index = build_loc_net( 55 | fc_struc, list(train.columns), feature_map=feature_map 56 | ) 57 | fc_edge_index = torch.tensor(fc_edge_index, dtype=torch.long) 58 | 59 | train_dataset_indata = construct_data(train, feature_map, labels=0) 60 | test_dataset_indata = construct_data( 61 | test, feature_map, labels=test.attack.tolist() 62 | ) 63 | 64 | cfg = { 65 | "slide_win": train_config["slide_win"], 66 | "slide_stride": train_config["slide_stride"], 67 | } 68 | 69 | train_dataset = TimeDataset( 70 | train_dataset_indata, 71 | fc_edge_index, 72 | mode="train", 73 | task="forecasting", 74 | config=cfg, 75 | ) 76 | min_train, max_train = train_dataset.get_train_min_max() 77 | test_dataset = TimeDataset( 78 | test_dataset_indata, 79 | fc_edge_index, 80 | mode="test", 81 | task="forecasting", 82 | min_train=min_train, 83 | max_train=max_train, 84 | config=cfg, 85 | ) 86 | 87 | train_dataloader, val_dataloader = self.get_loaders( 88 | train_dataset, train_config["batch"], val_ratio=train_config["val_ratio"] 89 | ) 90 | 91 | self.train_dataset = train_dataset 92 | self.test_dataset = test_dataset 93 | 94 | self.train_dataloader = train_dataloader 95 | self.val_dataloader = val_dataloader 96 | self.test_dataloader = DataLoader( 97 | test_dataset, 98 | batch_size=train_config["batch"], 99 | shuffle=False, 100 | num_workers=4, 101 | pin_memory=True, 102 | ) 103 | 104 | if env_config["dataset"] == "hai": 105 | process_dict = {"P1": 38, "P2": 22, "P3": 7, "P4": 12} 106 | elif env_config["dataset"] == "swat": 107 | process_dict = {"P1": 5, "P2": 11, "P3": 9, "P4": 9, "P5": 13, "P6": 
4} 108 | elif env_config["dataset"] == "wadi": 109 | process_dict = {"P1": 19, "P2": 90, "P3": 15, "P4": 3} 110 | 111 | edge_index_sets = [fc_edge_index] 112 | self.model = FuSAGNet( 113 | edge_index_sets=edge_index_sets, 114 | node_num=len(feature_map), 115 | dim=train_config["dim"], 116 | window_size=train_config["slide_win"], 117 | out_layer_num=train_config["out_layer_num"], 118 | out_layer_inter_dim=train_config["out_layer_inter_dim"], 119 | topk=train_config["topk"], 120 | process_dict=process_dict, 121 | ).to(self.device) 122 | 123 | def run(self): 124 | if len(self.env_config["load_model_path"]) > 0: 125 | model_save_path = self.env_config["load_model_path"] 126 | else: 127 | model_save_path = self.get_save_path()[0] 128 | self.train_log = train( 129 | self.model, 130 | save_path=model_save_path, 131 | config=train_config, 132 | train_dataloader=self.train_dataloader, 133 | val_dataloader=self.val_dataloader, 134 | device=self.device, 135 | test_dataloader=self.test_dataloader, 136 | test_dataset=self.test_dataset, 137 | train_dataset=self.train_dataset, 138 | dataset_name=self.env_config["dataset"], 139 | ) 140 | 141 | self.model.load_state_dict(torch.load(model_save_path)) 142 | best_model = self.model.to(self.device) 143 | 144 | _, self.test_result = test( 145 | best_model, self.test_dataloader, device=self.device, config=train_config 146 | ) 147 | _, self.val_result = test( 148 | best_model, self.val_dataloader, device=self.device, config=train_config 149 | ) 150 | 151 | self.get_score( 152 | self.test_result["forecasting"], 153 | self.test_result["reconstruction"], 154 | self.val_result["forecasting"], 155 | self.val_result["reconstruction"], 156 | self.train_config, 157 | model_save_path, 158 | self.env_config["dataset"], 159 | ) 160 | 161 | def get_loaders(self, train_dataset, batch, val_ratio=0.1): 162 | dataset_len = int(len(train_dataset)) 163 | train_use_len = int(dataset_len * (1 - val_ratio)) 164 | val_use_len = int(dataset_len * val_ratio) 165 
| val_start_index = random.randrange(train_use_len) 166 | indices = torch.arange(dataset_len) 167 | 168 | train_sub_indices = torch.cat( 169 | [indices[:val_start_index], indices[val_start_index + val_use_len :]] 170 | ) 171 | train_subset = Subset(train_dataset, train_sub_indices) 172 | 173 | val_sub_indices = indices[val_start_index : val_start_index + val_use_len] 174 | val_subset = Subset(train_dataset, val_sub_indices) 175 | 176 | train_dataloader = DataLoader( 177 | train_subset, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True 178 | ) 179 | val_dataloader = DataLoader( 180 | val_subset, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True 181 | ) 182 | 183 | return train_dataloader, val_dataloader 184 | 185 | def get_score( 186 | self, 187 | test_result_f, 188 | test_result_r, 189 | val_result_f, 190 | val_result_r, 191 | config, 192 | save_path, 193 | dataset_name, 194 | ): 195 | def whm(x1, x2, w1, w2): 196 | epsilon = 1e-2 197 | return (w1 + w2) * (x1 * x2) / (w1 * x2 + w2 * x1 + epsilon) 198 | 199 | _, _, test_labels_f = test_result_f 200 | test_labels = np.asarray(test_labels_f)[:, 0].tolist() 201 | 202 | alpha = config["alpha"] 203 | test_scores_f, _ = get_full_err_scores(test_result_f[:2], val_result_f[:2]) 204 | test_scores_r, _ = get_full_err_scores(test_result_r[:2], val_result_r[:2]) 205 | test_scores_whm = [] 206 | for i in range(len(test_scores_f)): 207 | score = whm( 208 | x1=test_scores_f[i], x2=test_scores_r[i], w1=alpha, w2=1.0 - alpha 209 | ) 210 | test_scores_whm.append(score) 211 | 212 | all_scores = [test_scores_f, test_scores_r, test_scores_whm] 213 | score_labels = ["Forecasting", "Reconstruction", "Weighted Harmonic Mean"] 214 | to_save = {} 215 | if self.env_config["report"] == "best": 216 | for i, scores in enumerate(all_scores): 217 | score_label = score_labels[i] 218 | if score_label != "Weighted Harmonic Mean": 219 | continue 220 | 221 | scores = np.array(scores) 222 | to_save[score_label] = scores 223 | 
top1_best_info = get_best_performance_data(scores, test_labels, topk=1) 224 | print( 225 | f"F1: {top1_best_info[0]:.4f} | Pr: {top1_best_info[1]:.4f} | Re: {top1_best_info[2]:.4f}" 226 | ) 227 | 228 | model_save_name = save_path.split("/")[-1].split(".")[0] 229 | results_save_path = f"./results/{dataset_name}/{model_save_name}/" 230 | if not os.path.exists(results_save_path): 231 | os.makedirs(results_save_path, exist_ok=True) 232 | 233 | for k in to_save: 234 | f2save = to_save.get(k) 235 | np.save(os.path.join(results_save_path, k), f2save) 236 | 237 | def get_save_path(self, feature_name=""): 238 | dir_path = self.env_config["dataset"] 239 | if self.datestr is None: 240 | now = datetime.now(timezone("Asia/Seoul")) 241 | self.datestr = now.strftime("%m|%d-%H:%M:%S") 242 | 243 | datestr = self.datestr 244 | paths = [ 245 | f"./pretrained/{dir_path}/best_{datestr}.pt", 246 | f"./results/{dir_path}/{datestr}.csv", 247 | ] 248 | 249 | for path in paths: 250 | dirname = os.path.dirname(path) 251 | Path(dirname).mkdir(parents=True, exist_ok=True) 252 | 253 | return paths 254 | 255 | 256 | if __name__ == "__main__": 257 | parser = argparse.ArgumentParser() 258 | 259 | parser.add_argument("-batch", help="batch size", type=int, default=32) 260 | parser.add_argument("-epoch", help="train epoch", type=int, default=50) 261 | parser.add_argument("-slide_win", help="window size", type=int, default=5) 262 | parser.add_argument("-dim", help="dimension", type=int, default=64) 263 | parser.add_argument("-slide_stride", help="window stride", type=int, default=1) 264 | parser.add_argument("-save_path_pattern", help="save path", type=str, default="") 265 | parser.add_argument("-dataset", help="hai/swat/wadi", type=str, default="swat") 266 | parser.add_argument("-device", help="cpu/cuda", type=str, default="cuda") 267 | parser.add_argument("-random_seed", help="random seed", type=int, default=-999) 268 | parser.add_argument("-comment", help="experiment comment", type=str, default="") 
    # (continuation of the command-line options defined above)
    parser.add_argument(
        "-out_layer_num", help="out layer dimension", type=int, default=1
    )
    parser.add_argument(
        "-out_layer_inter_dim",
        help="intermediate out layer dimension",
        type=int,
        default=64,
    )
    parser.add_argument("-decay", help="weight decay", type=float, default=0)
    parser.add_argument(
        "-val_ratio", help="validation data ratio", type=float, default=0.2
    )
    parser.add_argument("-topk", help="k", type=int, default=15)
    parser.add_argument("-report", help="best/val", type=str, default="best")
    parser.add_argument(
        "-load_model_path", help="trained model path", type=str, default=""
    )
    parser.add_argument("-lr", help="learning rate", type=float, default=1e-3)
    parser.add_argument("-gpu_id", help="gpu device ID", type=int, default=1)
    parser.add_argument(
        "-alpha", help="forecasting loss weight", type=float, default=0.5
    )
    parser.add_argument("-beta", help="sparse loss weight", type=float, default=1.0)

    args = parser.parse_args()

    # Seed every RNG source for reproducibility; a negative seed (the default
    # sentinel -999) means "draw a random seed in [0, 100]".
    if args.random_seed < 0:
        args.random_seed = random.randint(0, 100)
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    # Disable cuDNN autotuning / enable deterministic kernels so runs with the
    # same seed are repeatable (at some speed cost).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.enabled = False
    os.environ["PYTHONHASHSEED"] = str(args.random_seed)

    # Hyper-parameters consumed by the model/trainer.
    train_config = {
        "batch": args.batch,
        "epoch": args.epoch,
        "slide_win": args.slide_win,
        "dim": args.dim,
        "slide_stride": args.slide_stride,
        "comment": args.comment,
        "seed": args.random_seed,
        "out_layer_num": args.out_layer_num,
        "out_layer_inter_dim": args.out_layer_inter_dim,
        "decay": args.decay,
        "val_ratio": args.val_ratio,
        "topk": args.topk,
        "lr": args.lr,
        "gpu_id": args.gpu_id,
        "alpha": args.alpha,
        "beta": args.beta,
    }

    # Environment/bookkeeping settings (paths, dataset choice, reporting mode).
    env_config = {
        "save_path": args.save_path_pattern,
        "dataset": args.dataset,
        "report": args.report,
        "device": args.device,
        "load_model_path": args.load_model_path,
    }

    main = Main(train_config, env_config, debug=False)
    main.run()