├── README.md ├── data ├── synthetic_data_with_anomaly-s-1.csv └── test_anomaly.csv ├── main.py ├── model ├── __init__.py ├── convolution_lstm.py └── mscred.py └── utils ├── __init__.py ├── data.py ├── evaluate.py └── matrix_generator.py /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-MSCRED 2 | 3 | 这是使用PyTorch实现MSCRED 4 | 5 | 论文原文: 6 | [http://in.arxiv.org/abs/1811.08055](http://in.arxiv.org/abs/1811.08055) 7 | 8 | TensorFlow实现地址: 9 | [https://github.com/7fantasysz/MSCRED](https://github.com/7fantasysz/MSCRED) 10 | 11 | 此项目就是通过上面tensorFlow转为Pytorch,具体流程如下: 12 | - 先将时间序列数据转换为 image matrices 13 | 14 | > python ./utils/matrix_generator.py 15 | 16 | - 然后训练模型并对测试集生成相应的reconstructed matrices 17 | 18 | > python main.py 19 | 20 | - 最后评估模型,结果存在`outputs`文件夹中 21 | 22 | > python ./utils/evaluate.py -------------------------------------------------------------------------------- /data/test_anomaly.csv: -------------------------------------------------------------------------------- 1 | 11810,24,15,28 2 | 12760,21,26,5 3 | 14540,3,16,2 4 | 17790,9,5,20 5 | 18620,25,14,8 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.functional as F 4 | from tqdm import tqdm 5 | from model.mscred import MSCRED 6 | from utils.data import load_data 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import os 10 | 11 | def train(dataLoader, model, optimizer, epochs, device): 12 | model = model.to(device) 13 | print("------training on {}-------".format(device)) 14 | for epoch in range(epochs): 15 | train_l_sum,n = 0.0, 0 16 | for x in tqdm(dataLoader): 17 | x = x.to(device) 18 | x = x.squeeze() 19 | #print(type(x)) 20 | l = torch.mean((model(x)-x[-1].unsqueeze(0))**2) 21 | train_l_sum += l 22 | optimizer.zero_grad() 23 | l.backward() 24 | optimizer.step() 25 | n += 1 26 | #print("[Epoch %d/%d][Batch %d/%d] [loss: %f]" % (epoch+1, epochs, n, len(dataLoader), l.item())) 27 | 28 | print("[Epoch %d/%d] [loss: %f]" % (epoch+1, epochs, train_l_sum/n)) 29 | 30 | def test(dataLoader, model): 31 | print("------Testing-------") 32 | index = 800 33 | loss_list = [] 34 | reconstructed_data_path = "./data/matrix_data/reconstructed_data/" 35 | with torch.no_grad(): 36 | for x in dataLoader: 37 | x = x.to(device) 38 | x = x.squeeze() 39 | reconstructed_matrix = model(x) 40 | path_temp = os.path.join(reconstructed_data_path, 'reconstructed_data_' + str(index) + ".npy") 41 | np.save(path_temp, reconstructed_matrix.cpu().detach().numpy()) 42 | # l = criterion(reconstructed_matrix, x[-1].unsqueeze(0)).mean() 43 | # loss_list.append(l) 44 | # print("[test_index %d] [loss: %f]" % (index, l.item())) 45 | index += 1 46 | 47 | 48 | if __name__ == '__main__': 49 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 50 | print("device is", device) 51 | dataLoader = load_data() 52 | mscred = MSCRED(3, 256) 53 | 54 | # 训练阶段 55 | # mscred.load_state_dict(torch.load("./checkpoints/model1.pth")) 56 | optimizer = torch.optim.Adam(mscred.parameters(), lr = 0.0002) 57 | train(dataLoader["train"], mscred, optimizer, 10, device) 58 | print("保存模型中....") 59 | torch.save(mscred.state_dict(), "./checkpoints/model2.pth") 60 | 61 | # # 测试阶段 62 | mscred.load_state_dict(torch.load("./checkpoints/model2.pth")) 63 | mscred.to(device) 64 | test(dataLoader["test"], mscred) -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhang-Zhi-Jie/Pytorch-MSCRED/357bd56729356d38bb3e291637499c538a828e0c/model/__init__.py -------------------------------------------------------------------------------- /model/convolution_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | device = "cuda" if torch.cuda.is_available else "cpu" 6 | 7 | class ConvLSTMCell(nn.Module): 8 | def __init__(self, input_channels, hidden_channels, kernel_size): 9 | super(ConvLSTMCell, self).__init__() 10 | 11 | assert hidden_channels % 2 == 0 12 | 13 | self.input_channels = input_channels 14 | self.hidden_channels = hidden_channels 15 | self.kernel_size = kernel_size 16 | self.num_features = 4 17 | 18 | self.padding = int((kernel_size - 1) / 2) 19 | 20 | self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True) 21 | self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False) 22 | self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True) 23 | self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False) 24 | self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True) 25 | self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False) 26 | self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True) 27 | self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False) 28 | 29 | self.Wci = None 30 | self.Wcf = None 31 | self.Wco = None 32 | 33 | def forward(self, x, h, c): 34 | ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci) 35 | cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf) 36 | cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h)) 37 | co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco) 38 | ch = co * torch.tanh(cc) 39 | return ch, cc 40 | 41 | def init_hidden(self, batch_size, hidden, shape): 42 | if self.Wci is None: 43 | self.Wci = Variable(torch.zeros(1, hidden, shape[0], shape[1])).to("cpu") 44 | self.Wcf = Variable(torch.zeros(1, hidden, shape[0], shape[1])).to("cpu") 45 | self.Wco = Variable(torch.zeros(1, hidden, shape[0], shape[1])).to("cpu") 46 | else: 47 | assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!' 48 | assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!' 49 | return (Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).to("cpu"), 50 | Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).to("cpu")) 51 | 52 | 53 | class ConvLSTM(nn.Module): 54 | # input_channels corresponds to the first input feature map 55 | # hidden state is a list of succeeding lstm layers. 56 | def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=[1]): 57 | super(ConvLSTM, self).__init__() 58 | self.input_channels = [input_channels] + hidden_channels 59 | self.hidden_channels = hidden_channels 60 | self.kernel_size = kernel_size 61 | self.num_layers = len(hidden_channels) 62 | self.step = step 63 | self.effective_step = effective_step 64 | self._all_layers = [] 65 | for i in range(self.num_layers): 66 | name = 'cell{}'.format(i) 67 | cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size) 68 | setattr(self, name, cell) 69 | self._all_layers.append(cell) 70 | 71 | def forward(self, input): 72 | internal_state = [] 73 | outputs = [] 74 | for step in range(self.step): 75 | x = input 76 | for i in range(self.num_layers): 77 | # all cells are initialized in the first step 78 | name = 'cell{}'.format(i) 79 | if step == 0: 80 | bsize, _, height, width = x.size() 81 | (h, c) = getattr(self, name).init_hidden(batch_size=bsize, hidden=self.hidden_channels[i], 82 | shape=(height, width)) 83 | internal_state.append((h, c)) 84 | 85 | # do forward 86 | (h, c) = internal_state[i] 87 | x, new_c = getattr(self, name)(x, h, c) 88 | internal_state[i] = (x, new_c) 89 | # only record effective steps 90 | if step in self.effective_step: 91 | outputs.append(x) 92 | 93 | return outputs, (x, new_c) 94 | 95 | 96 | if __name__ == '__main__': 97 | # gradient check 98 | convlstm = ConvLSTM(input_channels=512, hidden_channels=[128, 64, 64, 32, 32], kernel_size=3, step=5, 99 | effective_step=[4]).to("cpu") 100 | loss_fn = torch.nn.MSELoss() 101 | 102 | input = Variable(torch.randn(5, 512, 64, 32)).to("cpu") 103 | target = Variable(torch.randn(1, 32, 64, 32)).double().to("cpu") 104 | 105 | output = convlstm(input) 106 | output = output[0][0].double() 107 | print(output.shape) 108 | # res = torch.autograd.gradcheck(loss_fn, (output, target), eps=1e-6, raise_exception=True) 109 | # print(res) 110 | -------------------------------------------------------------------------------- /model/mscred.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from model.convolution_lstm import ConvLSTM 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | def attention(ConvLstm_out): 8 | attention_w = [] 9 | for k in range(5): 10 | attention_w.append(torch.sum(torch.mul(ConvLstm_out[k], ConvLstm_out[-1]))/5) 11 | m = nn.Softmax() 12 | attention_w = torch.reshape(m(torch.stack(attention_w)), (-1, 5)) 13 | cl_out_shape = ConvLstm_out.shape 14 | ConvLstm_out = torch.reshape(ConvLstm_out, (5, -1)) 15 | convLstmOut = torch.matmul(attention_w, ConvLstm_out) 16 | convLstmOut = torch.reshape(convLstmOut, (cl_out_shape[1], cl_out_shape[2], cl_out_shape[3])) 17 | return convLstmOut 18 | 19 | class CnnEncoder(nn.Module): 20 | def __init__(self, in_channels_encoder): 21 | super(CnnEncoder, self).__init__() 22 | self.conv1 = nn.Sequential( 23 | nn.Conv2d(in_channels_encoder, 32, 3, (1, 1), 1), 24 | nn.SELU() 25 | ) 26 | self.conv2 = nn.Sequential( 27 | nn.Conv2d(32, 64, 3, (2, 2), 1), 28 | nn.SELU() 29 | ) 30 | self.conv3 = nn.Sequential( 31 | nn.Conv2d(64, 128, 2, (2, 2), 1), 32 | nn.SELU() 33 | ) 34 | self.conv4 = nn.Sequential( 35 | nn.Conv2d(128, 256, 2, (2, 2), 0), 36 | nn.SELU() 37 | ) 38 | def forward(self, X): 39 | conv1_out = self.conv1(X) 40 | conv2_out = self.conv2(conv1_out) 41 | conv3_out = self.conv3(conv2_out) 42 | conv4_out = self.conv4(conv3_out) 43 | return conv1_out, conv2_out, conv3_out, conv4_out 44 | 45 | 46 | class Conv_LSTM(nn.Module): 47 | def __init__(self): 48 | super(Conv_LSTM, self).__init__() 49 | self.conv1_lstm = ConvLSTM(input_channels=32, hidden_channels=[32], 50 | kernel_size=3, step=5, effective_step=[4]) 51 | self.conv2_lstm = ConvLSTM(input_channels=64, hidden_channels=[64], 52 | kernel_size=3, step=5, effective_step=[4]) 53 | self.conv3_lstm = ConvLSTM(input_channels=128, hidden_channels=[128], 54 | kernel_size=3, step=5, effective_step=[4]) 55 | self.conv4_lstm = ConvLSTM(input_channels=256, hidden_channels=[256], 56 | kernel_size=3, step=5, effective_step=[4]) 57 | 58 | def forward(self, conv1_out, conv2_out, 59 | conv3_out, conv4_out): 60 | conv1_lstm_out = self.conv1_lstm(conv1_out) 61 | conv1_lstm_out = attention(conv1_lstm_out[0][0]) 62 | conv2_lstm_out = self.conv2_lstm(conv2_out) 63 | conv2_lstm_out = attention(conv2_lstm_out[0][0]) 64 | conv3_lstm_out = self.conv3_lstm(conv3_out) 65 | conv3_lstm_out = attention(conv3_lstm_out[0][0]) 66 | conv4_lstm_out = self.conv4_lstm(conv4_out) 67 | conv4_lstm_out = attention(conv4_lstm_out[0][0]) 68 | return conv1_lstm_out.unsqueeze(0), conv2_lstm_out.unsqueeze(0), conv3_lstm_out.unsqueeze(0), conv4_lstm_out.unsqueeze(0) 69 | 70 | class CnnDecoder(nn.Module): 71 | def __init__(self, in_channels): 72 | super(CnnDecoder, self).__init__() 73 | self.deconv4 = nn.Sequential( 74 | nn.ConvTranspose2d(in_channels, 128, 2, 2, 0, 0), 75 | nn.SELU() 76 | ) 77 | self.deconv3 = nn.Sequential( 78 | nn.ConvTranspose2d(256, 64, 2, 2, 1, 1), 79 | nn.SELU() 80 | ) 81 | self.deconv2 = nn.Sequential( 82 | nn.ConvTranspose2d(128, 32, 3, 2, 1, 1), 83 | nn.SELU() 84 | ) 85 | self.deconv1 = nn.Sequential( 86 | nn.ConvTranspose2d(64, 3, 3, 1, 1, 0), 87 | nn.SELU() 88 | ) 89 | 90 | def forward(self, conv1_lstm_out, conv2_lstm_out, conv3_lstm_out, conv4_lstm_out): 91 | deconv4 = self.deconv4(conv4_lstm_out) 92 | deconv4_concat = torch.cat((deconv4, conv3_lstm_out), dim = 1) 93 | deconv3 = self.deconv3(deconv4_concat) 94 | deconv3_concat = torch.cat((deconv3, conv2_lstm_out), dim = 1) 95 | deconv2 = self.deconv2(deconv3_concat) 96 | deconv2_concat = torch.cat((deconv2, conv1_lstm_out), dim = 1) 97 | deconv1 = self.deconv1(deconv2_concat) 98 | return deconv1 99 | 100 | 101 | class MSCRED(nn.Module): 102 | def __init__(self, in_channels_encoder, in_channels_decoder): 103 | super(MSCRED, self).__init__() 104 | self.cnn_encoder = CnnEncoder(in_channels_encoder) 105 | self.conv_lstm = Conv_LSTM() 106 | self.cnn_decoder = CnnDecoder(in_channels_decoder) 107 | 108 | def forward(self, x): 109 | conv1_out, conv2_out, conv3_out, conv4_out = self.cnn_encoder(x) 110 | conv1_lstm_out, conv2_lstm_out, conv3_lstm_out, conv4_lstm_out = self.conv_lstm( 111 | conv1_out, conv2_out, conv3_out, conv4_out) 112 | 113 | gen_x = self.cnn_decoder(conv1_lstm_out, conv2_lstm_out, 114 | conv3_lstm_out, conv4_lstm_out) 115 | return gen_x 116 | 117 | 118 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zhang-Zhi-Jie/Pytorch-MSCRED/357bd56729356d38bb3e291637499c538a828e0c/utils/__init__.py -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | splits = ["train", "test"] 6 | train_data_path = "./data/matrix_data/train_data/" 7 | test_data_path = "./data/matrix_data/test_data/" 8 | shuffle = {'train': True, 'test': False} 9 | 10 | def load_data(): 11 | dataset = {} 12 | train_file_list = os.listdir(train_data_path) 13 | test_file_list = os.listdir(test_data_path) 14 | train_file_list.sort(key = lambda x:int(x[11:-4])) 15 | test_file_list.sort(key = lambda x:int(x[10:-4])) 16 | train_data, test_data = [],[] 17 | for obj in train_file_list: 18 | train_file_path = train_data_path + obj 19 | train_matrix = np.load(train_file_path) 20 | #train_matrix = np.transpose(train_matrix, (0, 2, 3, 1)) 21 | train_data.append(train_matrix) 22 | 23 | for obj in test_file_list: 24 | test_file_path = test_data_path + obj 25 | test_matrix = np.load(test_file_path) 26 | #test_matrix = np.transpose(test_matrix, (0, 2, 3, 1)) 27 | test_data.append(test_matrix) 28 | 29 | dataset["train"] = torch.from_numpy(np.array(train_data)).float() 30 | dataset["test"] = torch.from_numpy(np.array(test_data)).float() 31 | 32 | dataloader = {x: torch.utils.data.DataLoader( 33 | dataset=dataset[x], batch_size=1, shuffle=shuffle[x]) 34 | for x in splits} 35 | return dataloader -------------------------------------------------------------------------------- /utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | import string 5 | import re 6 | import math 7 | import os 8 | import torch 9 | 10 | parser = argparse.ArgumentParser(description = 'MSCRED evaluation') 11 | parser.add_argument('--thred_broken', type = int, default = 0.005, 12 | help = 'broken pixel thred') 13 | parser.add_argument('--alpha', type = int, default = 1.5, 14 | help = 'scale coefficient of max valid anomaly') 15 | parser.add_argument('--valid_start_point', type = int, default = 8000, 16 | help = 'test start point') 17 | parser.add_argument('--valid_end_point', type = int, default = 10000, 18 | help = 'test end point') 19 | parser.add_argument('--test_start_point', type = int, default = 10000, 20 | help = 'test start point') 21 | parser.add_argument('--test_end_point', type = int, default = 20000, 22 | help = 'test end point') 23 | parser.add_argument('--gap_time', type = int, default = 10, 24 | help = 'gap time between each segment') 25 | parser.add_argument('--matrix_data_path', type = str, default = './data/matrix_data/', 26 | help='matrix data path') 27 | 28 | args = parser.parse_args() 29 | print(args) 30 | 31 | thred_b = args.thred_broken 32 | alpha = args.alpha 33 | gap_time = args.gap_time 34 | valid_start = args.valid_start_point//gap_time 35 | valid_end = args.valid_end_point//gap_time 36 | test_start = args.test_start_point//gap_time 37 | test_end = args.test_end_point//gap_time 38 | 39 | valid_anomaly_score = np.zeros((valid_end - valid_start , 1)) 40 | test_anomaly_score = np.zeros((test_end - test_start, 1)) 41 | 42 | matrix_data_path = args.matrix_data_path 43 | test_data_path = matrix_data_path + "test_data/" 44 | reconstructed_data_path = matrix_data_path + "reconstructed_data/" 45 | #reconstructed_data_path = matrix_data_path + "matrix_pred_data/" 46 | criterion = torch.nn.MSELoss() 47 | 48 | for i in range(valid_start, test_end): 49 | path_temp_1 = os.path.join(test_data_path, "test_data_" + str(i) + '.npy') 50 | gt_matrix_temp = np.load(path_temp_1) 51 | 52 | path_temp_2 = os.path.join(reconstructed_data_path, "reconstructed_data_" + str(i) + '.npy') 53 | #path_temp_2 = os.path.join(reconstructed_data_path, "pcc_matrix_full_test_" + str(i) + '_pred_output.npy') 54 | reconstructed_matrix_temp = np.load(path_temp_2) 55 | # reconstructed_matrix_temp = np.transpose(reconstructed_matrix_temp, [0, 3, 1, 2]) 56 | #print(reconstructed_matrix_temp.shape) 57 | #first (short) duration scale for evaluation 58 | select_gt_matrix = np.array(gt_matrix_temp)[-1][0] #get last step matrix 59 | 60 | select_reconstructed_matrix = np.array(reconstructed_matrix_temp)[0][0] 61 | 62 | #compute number of broken element in residual matrix 63 | select_matrix_error = np.square(np.subtract(select_gt_matrix, select_reconstructed_matrix)) 64 | num_broken = len(select_matrix_error[select_matrix_error > thred_b]) 65 | 66 | #print num_broken 67 | if i < valid_end: 68 | valid_anomaly_score[i - valid_start] = num_broken 69 | else: 70 | test_anomaly_score[i - test_start] = num_broken 71 | valid_anomaly_max = np.max(valid_anomaly_score.ravel()) 72 | test_anomaly_score = test_anomaly_score.ravel() 73 | #print(test_anomaly_score) 74 | # plot anomaly score curve and identification result 75 | anomaly_pos = np.zeros(5) 76 | root_cause_gt = np.zeros((5, 3)) 77 | anomaly_span = [10, 30, 90] 78 | root_cause_f = open("./data/test_anomaly.csv", "r") 79 | row_index = 0 80 | for line in root_cause_f: 81 | line=line.strip() 82 | anomaly_axis = int(re.split(',',line)[0]) 83 | anomaly_pos[row_index] = anomaly_axis/gap_time - test_start - anomaly_span[row_index%3]/gap_time 84 | #print(anomaly_pos[row_index]) 85 | root_list = re.split(',',line)[1:] 86 | for k in range(len(root_list)-1): 87 | root_cause_gt[row_index][k] = int(root_list[k]) 88 | row_index += 1 89 | root_cause_f.close() 90 | 91 | fig, axes = plt.subplots() 92 | #plt.plot(test_anomaly_score, 'black', linewidth = 2) 93 | test_num = test_end - test_start 94 | # plt.xticks(fontsize = 25) 95 | # plt.ylim((0, 100)) 96 | # plt.yticks(np.arange(0, 101, 20), fontsize = 25) 97 | plt.plot(test_anomaly_score, color = 'black', linewidth = 2) 98 | threshold = np.full((test_num), valid_anomaly_max * alpha) 99 | axes.plot(threshold, color = 'black', linestyle = '--',linewidth = 2) 100 | for k in range(len(anomaly_pos)): 101 | axes.axvspan(anomaly_pos[k], anomaly_pos[k] + anomaly_span[k%3]/gap_time, color='red', linewidth=2) 102 | labels = [' ', '0e3', '2e3', '4e3', '6e3', '8e3', '10e3'] 103 | # axes.set_xticklabels(labels, rotation = 25, fontsize = 20) 104 | plt.xlabel('Test Time', fontsize = 25) 105 | plt.ylabel('Anomaly Score', fontsize = 25) 106 | axes.spines['right'].set_visible(False) 107 | axes.spines['top'].set_visible(False) 108 | axes.yaxis.set_ticks_position('left') 109 | axes.xaxis.set_ticks_position('bottom') 110 | fig.subplots_adjust(bottom=0.25) 111 | fig.subplots_adjust(left=0.25) 112 | plt.title("MSCRED", size = 25) 113 | plt.savefig('./outputs/anomaly_score.jpg') 114 | plt.show() 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /utils/matrix_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import pandas as pd 4 | import os, sys 5 | import math 6 | import scipy 7 | #import matplotlib.pyplot as plt 8 | from scipy.stats import pearsonr 9 | from scipy import spatial 10 | import itertools as it 11 | import string 12 | import re 13 | 14 | 15 | 16 | parser = argparse.ArgumentParser(description = 'Signature Matrix Generator') 17 | parser.add_argument('--ts_type', type = str, default = "node", 18 | help = 'type of time series: node or link') 19 | parser.add_argument('--step_max', type = int, default = 5, 20 | help = 'maximum step in ConvLSTM') 21 | parser.add_argument('--gap_time', type = int, default = 10, # tride width... 22 | help = 'gap time between each segment') 23 | parser.add_argument('--win_size', type = int, default = [10, 30, 60], 24 | help = 'window size of each segment') 25 | parser.add_argument('--min_time', type = int, default = 0, 26 | help = 'minimum time point') 27 | parser.add_argument('--max_time', type = int, default = 20000, 28 | help = 'maximum time point') 29 | parser.add_argument('--train_start_point', type = int, default = 0, 30 | help = 'train start point') 31 | parser.add_argument('--train_end_point', type = int, default = 8000, 32 | help = 'train end point') 33 | parser.add_argument('--test_start_point', type = int, default = 8000, 34 | help = 'test start point') 35 | parser.add_argument('--test_end_point', type = int, default = 20000, 36 | help = 'test end point') 37 | parser.add_argument('--raw_data_path', type = str, default = './data/synthetic_data_with_anomaly-s-1.csv', 38 | help='path to load raw data') 39 | parser.add_argument('--save_data_path', type = str, default = './data/', 40 | help='path to save data') 41 | 42 | args = parser.parse_args() 43 | print(args) 44 | 45 | ts_type = args.ts_type 46 | step_max = args.step_max 47 | min_time = args.min_time 48 | max_time = args.max_time 49 | gap_time = args.gap_time 50 | win_size = args.win_size 51 | 52 | train_start = args.train_start_point 53 | train_end = args.train_end_point 54 | test_start = args.test_start_point 55 | test_end = args.test_end_point 56 | 57 | raw_data_path = args.raw_data_path 58 | save_data_path = args.save_data_path 59 | 60 | ts_colname="agg_time_interval" 61 | agg_freq='5min' 62 | 63 | matrix_data_path = save_data_path + "matrix_data/" 64 | if not os.path.exists(matrix_data_path): 65 | os.makedirs(matrix_data_path) 66 | 67 | 68 | def generate_signature_matrix_node(): 69 | data = np.array(pd.read_csv(raw_data_path, header = None), dtype=np.float64) 70 | sensor_n = data.shape[0] 71 | # min-max normalization 72 | max_value = np.max(data, axis=1) 73 | min_value = np.min(data, axis=1) 74 | data = (np.transpose(data) - min_value)/(max_value - min_value + 1e-6) 75 | data = np.transpose(data) 76 | 77 | #multi-scale signature matix generation 78 | for w in range(len(win_size)): 79 | matrix_all = [] 80 | win = win_size[w] 81 | print ("generating signature with window " + str(win) + "...") 82 | for t in range(min_time, max_time, gap_time): 83 | #print t 84 | matrix_t = np.zeros((sensor_n, sensor_n)) 85 | if t >= 60: 86 | for i in range(sensor_n): 87 | for j in range(i, sensor_n): 88 | #if np.var(data[i, t - win:t]) and np.var(data[j, t - win:t]): 89 | matrix_t[i][j] = np.inner(data[i, t - win:t], data[j, t - win:t])/(win) # rescale by win 90 | matrix_t[j][i] = matrix_t[i][j] 91 | matrix_all.append(matrix_t) 92 | path_temp = matrix_data_path + "matrix_win_" + str(win) 93 | np.save(path_temp, matrix_all) 94 | del matrix_all[:] 95 | 96 | print ("matrix generation finish!") 97 | 98 | def generate_train_test_data(): 99 | #data sample generation 100 | print ("generating train/test data samples...") 101 | matrix_data_path = save_data_path + "matrix_data/" 102 | 103 | train_data_path = matrix_data_path + "train_data/" 104 | if not os.path.exists(train_data_path): 105 | os.makedirs(train_data_path) 106 | test_data_path = matrix_data_path + "test_data/" 107 | if not os.path.exists(test_data_path): 108 | os.makedirs(test_data_path) 109 | 110 | data_all = [] 111 | # for value_col in value_colnames: 112 | for w in range(len(win_size)): 113 | #path_temp = matrix_data_path + "matrix_win_" + str(win_size[w]) + str(value_col) + ".npy" 114 | path_temp = matrix_data_path + "matrix_win_" + str(win_size[w]) + ".npy" 115 | data_all.append(np.load(path_temp)) 116 | 117 | train_test_time = [[train_start, train_end], [test_start, test_end]] 118 | for i in range(len(train_test_time)): 119 | for data_id in range(int(train_test_time[i][0]/gap_time), int(train_test_time[i][1]/gap_time)): 120 | #print data_id 121 | step_multi_matrix = [] 122 | for step_id in range(step_max, 0, -1): 123 | multi_matrix = [] 124 | # for k in range(len(value_colnames)): 125 | for i in range(len(win_size)): 126 | multi_matrix.append(data_all[i][data_id - step_id]) 127 | step_multi_matrix.append(multi_matrix) 128 | 129 | if data_id >= (train_start/gap_time + win_size[-1]/gap_time + step_max) and data_id < (train_end/gap_time): # remove start points with invalid value 130 | path_temp = os.path.join(train_data_path, 'train_data_' + str(data_id)) 131 | np.save(path_temp, step_multi_matrix) 132 | elif data_id >= (test_start/gap_time) and data_id < (test_end/gap_time): 133 | path_temp = os.path.join(test_data_path, 'test_data_' + str(data_id)) 134 | np.save(path_temp, step_multi_matrix) 135 | 136 | #print np.shape(step_multi_matrix) 137 | 138 | del step_multi_matrix[:] 139 | 140 | print ("train/test data generation finish!") 141 | 142 | 143 | if __name__ == '__main__': 144 | '''need one more dimension to manage mulitple "features" for each node or link in each time point, 145 | this multiple features can be simply added as extra channels 146 | ''' 147 | 148 | if ts_type == "node": 149 | generate_signature_matrix_node() 150 | 151 | generate_train_test_data() --------------------------------------------------------------------------------