├── model.pth ├── README.md ├── LICENSE ├── utils.py ├── usad.py ├── gdrivedl.py └── USAD.ipynb /model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manigalati/usad/HEAD/model.pth -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # USAD - UnSupervised Anomaly Detection on multivariate time series 2 | 3 | Scripts and utility programs for implementing the USAD architecture. 4 | 5 | Implementation by: Francesco Galati. 6 | 7 | Additional contributions: Julien Audibert, Maria A. Zuluaga. 8 | 9 | ## How to cite 10 | 11 | If you use this software, please cite the following paper as appropriate: 12 | 13 | Audibert, J., Michiardi, P., Guyard, F., Marti, S., Zuluaga, M. A. (2020). 14 | USAD : UnSupervised Anomaly Detection on multivariate time series. 15 | Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, August 23-27, 2020 16 | 17 | ## Requirements 18 | * PyTorch 1.6.0 19 | * CUDA 10.1 (to allow use of GPU, not compulsory) 20 | 21 | ## Running the Software 22 | 23 | All the python classes and functions strictly needed to implement the USAD architecture can be found in `usad.py`. 24 | An example of an application deployed with the [SWaT dataset] is included in `USAD.ipynb`. 25 | 26 | ## Copyright and licensing 27 | 28 | Copyright 2020 Eurecom. 29 | 30 | This software is released under the BSD-3 license. Please see the `LICENSE` file for details. 31 | 32 | ## Publication 33 | 34 | Audibert et al. [USAD : UnSupervised Anomaly Detection on multivariate time series]. 
2020 35 | 36 | [SWaT dataset]: https://itrust.sutd.edu.sg/itrust-labs_datasets/dataset_info/#swat 37 | [USAD : UnSupervised Anomaly Detection on multivariate time series]: https://dl.acm.org/doi/pdf/10.1145/3394486.3403392 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | BSD License 3 | 4 | Copyright (c) 2020, EURECOM 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the documentation and/or 15 | other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
24 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 25 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 26 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 28 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 29 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 30 | OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | 33 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import torch 6 | 7 | from sklearn.metrics import roc_curve,roc_auc_score 8 | 9 | def get_default_device(): 10 | """Pick GPU if available, else CPU""" 11 | if torch.cuda.is_available(): 12 | return torch.device('cuda') 13 | else: 14 | return torch.device('cpu') 15 | 16 | def to_device(data, device): 17 | """Move tensor(s) to chosen device""" 18 | if isinstance(data, (list,tuple)): 19 | return [to_device(x, device) for x in data] 20 | return data.to(device, non_blocking=True) 21 | 22 | def plot_history(history): 23 | losses1 = [x['val_loss1'] for x in history] 24 | losses2 = [x['val_loss2'] for x in history] 25 | plt.plot(losses1, '-x', label="loss1") 26 | plt.plot(losses2, '-x', label="loss2") 27 | plt.xlabel('epoch') 28 | plt.ylabel('loss') 29 | plt.legend() 30 | plt.title('Losses vs. No. 
of epochs') 31 | plt.grid() 32 | plt.show() 33 | 34 | def histogram(y_test,y_pred): 35 | plt.figure(figsize=(12,6)) 36 | plt.hist([y_pred[y_test==0], 37 | y_pred[y_test==1]], 38 | bins=20, 39 | color = ['#82E0AA','#EC7063'],stacked=True) 40 | plt.title("Results",size=20) 41 | plt.grid() 42 | plt.show() 43 | 44 | def ROC(y_test,y_pred): 45 | fpr,tpr,tr=roc_curve(y_test,y_pred) 46 | auc=roc_auc_score(y_test,y_pred) 47 | idx=np.argwhere(np.diff(np.sign(tpr-(1-fpr)))).flatten() 48 | 49 | plt.xlabel("FPR") 50 | plt.ylabel("TPR") 51 | plt.plot(fpr,tpr,label="AUC="+str(auc)) 52 | plt.plot(fpr,1-fpr,'r:') 53 | plt.plot(fpr[idx],tpr[idx], 'ro') 54 | plt.legend(loc=4) 55 | plt.grid() 56 | plt.show() 57 | return tr[idx] 58 | 59 | def confusion_matrix(target, predicted, perc=False): 60 | 61 | data = {'y_Actual': target, 62 | 'y_Predicted': predicted 63 | } 64 | df = pd.DataFrame(data, columns=['y_Predicted','y_Actual']) 65 | confusion_matrix = pd.crosstab(df['y_Predicted'], df['y_Actual'], rownames=['Predicted'], colnames=['Actual']) 66 | 67 | if perc: 68 | sns.heatmap(confusion_matrix/np.sum(confusion_matrix), annot=True, fmt='.2%', cmap='Blues') 69 | else: 70 | sns.heatmap(confusion_matrix, annot=True, fmt='d') 71 | plt.show() -------------------------------------------------------------------------------- /usad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils import * 5 | device = get_default_device() 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, in_size, latent_size): 9 | super().__init__() 10 | self.linear1 = nn.Linear(in_size, int(in_size/2)) 11 | self.linear2 = nn.Linear(int(in_size/2), int(in_size/4)) 12 | self.linear3 = nn.Linear(int(in_size/4), latent_size) 13 | self.relu = nn.ReLU(True) 14 | 15 | def forward(self, w): 16 | out = self.linear1(w) 17 | out = self.relu(out) 18 | out = self.linear2(out) 19 | out = self.relu(out) 20 | out = self.linear3(out) 21 | z 
= self.relu(out) 22 | return z 23 | 24 | class Decoder(nn.Module): 25 | def __init__(self, latent_size, out_size): 26 | super().__init__() 27 | self.linear1 = nn.Linear(latent_size, int(out_size/4)) 28 | self.linear2 = nn.Linear(int(out_size/4), int(out_size/2)) 29 | self.linear3 = nn.Linear(int(out_size/2), out_size) 30 | self.relu = nn.ReLU(True) 31 | self.sigmoid = nn.Sigmoid() 32 | 33 | def forward(self, z): 34 | out = self.linear1(z) 35 | out = self.relu(out) 36 | out = self.linear2(out) 37 | out = self.relu(out) 38 | out = self.linear3(out) 39 | w = self.sigmoid(out) 40 | return w 41 | 42 | class UsadModel(nn.Module): 43 | def __init__(self, w_size, z_size): 44 | super().__init__() 45 | self.encoder = Encoder(w_size, z_size) 46 | self.decoder1 = Decoder(z_size, w_size) 47 | self.decoder2 = Decoder(z_size, w_size) 48 | 49 | def training_step(self, batch, n): 50 | z = self.encoder(batch) 51 | w1 = self.decoder1(z) 52 | w2 = self.decoder2(z) 53 | w3 = self.decoder2(self.encoder(w1)) 54 | loss1 = 1/n*torch.mean((batch-w1)**2)+(1-1/n)*torch.mean((batch-w3)**2) 55 | loss2 = 1/n*torch.mean((batch-w2)**2)-(1-1/n)*torch.mean((batch-w3)**2) 56 | return loss1,loss2 57 | 58 | def validation_step(self, batch, n): 59 | with torch.no_grad(): 60 | z = self.encoder(batch) 61 | w1 = self.decoder1(z) 62 | w2 = self.decoder2(z) 63 | w3 = self.decoder2(self.encoder(w1)) 64 | loss1 = 1/n*torch.mean((batch-w1)**2)+(1-1/n)*torch.mean((batch-w3)**2) 65 | loss2 = 1/n*torch.mean((batch-w2)**2)-(1-1/n)*torch.mean((batch-w3)**2) 66 | return {'val_loss1': loss1, 'val_loss2': loss2} 67 | 68 | def validation_epoch_end(self, outputs): 69 | batch_losses1 = [x['val_loss1'] for x in outputs] 70 | epoch_loss1 = torch.stack(batch_losses1).mean() 71 | batch_losses2 = [x['val_loss2'] for x in outputs] 72 | epoch_loss2 = torch.stack(batch_losses2).mean() 73 | return {'val_loss1': epoch_loss1.item(), 'val_loss2': epoch_loss2.item()} 74 | 75 | def epoch_end(self, epoch, result): 76 | print("Epoch 
[{}], val_loss1: {:.4f}, val_loss2: {:.4f}".format(epoch, result['val_loss1'], result['val_loss2'])) 77 | 78 | def evaluate(model, val_loader, n): 79 | outputs = [model.validation_step(to_device(batch,device), n) for [batch] in val_loader] 80 | return model.validation_epoch_end(outputs) 81 | 82 | def training(epochs, model, train_loader, val_loader, opt_func=torch.optim.Adam): 83 | history = [] 84 | optimizer1 = opt_func(list(model.encoder.parameters())+list(model.decoder1.parameters())) 85 | optimizer2 = opt_func(list(model.encoder.parameters())+list(model.decoder2.parameters())) 86 | for epoch in range(epochs): 87 | for [batch] in train_loader: 88 | batch=to_device(batch,device) 89 | 90 | #Train AE1 91 | loss1,loss2 = model.training_step(batch,epoch+1) 92 | loss1.backward() 93 | optimizer1.step() 94 | optimizer1.zero_grad() 95 | 96 | 97 | #Train AE2 98 | loss1,loss2 = model.training_step(batch,epoch+1) 99 | loss2.backward() 100 | optimizer2.step() 101 | optimizer2.zero_grad() 102 | 103 | 104 | result = evaluate(model, val_loader, epoch+1) 105 | model.epoch_end(epoch, result) 106 | history.append(result) 107 | return history 108 | 109 | def testing(model, test_loader, alpha=.5, beta=.5): 110 | results=[] 111 | with torch.no_grad(): 112 | for [batch] in test_loader: 113 | batch=to_device(batch,device) 114 | w1=model.decoder1(model.encoder(batch)) 115 | w2=model.decoder2(model.encoder(w1)) 116 | results.append(alpha*torch.mean((batch-w1)**2,axis=1)+beta*torch.mean((batch-w2)**2,axis=1)) 117 | return results -------------------------------------------------------------------------------- /gdrivedl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import unicode_literals 3 | import json 4 | import os 5 | import re 6 | import sys 7 | import unicodedata 8 | 9 | try: 10 | #Python3 11 | from urllib.request import Request, urlopen 12 | except ImportError: 13 | #Python2 14 | from urllib2 import 
Request, urlopen 15 | 16 | ITEM_URL = 'https://drive.google.com/open?id={id}' 17 | FILE_URL = 'https://docs.google.com/uc?export=download&id={id}&confirm={confirm}' 18 | FOLDER_URL = 'https://drive.google.com/drive/folders/{id}' 19 | 20 | ID_PATTERNS = [ 21 | re.compile('/file/d/([0-9A-Za-z_-]{10,})(?:/|$)', re.IGNORECASE), 22 | re.compile('id=([0-9A-Za-z_-]{10,})(?:&|$)', re.IGNORECASE), 23 | re.compile('([0-9A-Za-z_-]{10,})', re.IGNORECASE) 24 | ] 25 | FILE_PATTERN = re.compile("itemJson: (\[.*?)};", 26 | re.DOTALL | re.IGNORECASE) 27 | FOLDER_PATTERN = re.compile("window\['_DRIVE_ivd'\] = '(.*?)';", 28 | re.DOTALL | re.IGNORECASE) 29 | CONFIRM_PATTERN = re.compile("download_warning[0-9A-Za-z_-]+=([0-9A-Za-z_-]+);", 30 | re.IGNORECASE) 31 | FOLDER_TYPE = 'application/vnd.google-apps.folder' 32 | 33 | def output(text): 34 | try: 35 | sys.stdout.write(text) 36 | except UnicodeEncodeError: 37 | sys.stdout.write(text.encode('utf8')) 38 | 39 | # Big thanks to leo_wallentin for below sanitize function (modified slightly for this script) 40 | # https://gitlab.com/jplusplus/sanitize-filename/-/blob/master/sanitize_filename/sanitize_filename.py 41 | def sanitize(filename): 42 | blacklist = ["\\", "/", ":", "*", "?", "\"", "<", ">", "|", "\0"] 43 | reserved = [ 44 | "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", 45 | "COM6", "COM7", "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", 46 | "LPT6", "LPT7", "LPT8", "LPT9", 47 | ] 48 | 49 | filename = "".join(c for c in filename if c not in blacklist) 50 | filename = "".join(c for c in filename if 31 < ord(c)) 51 | filename = unicodedata.normalize("NFKD", filename) 52 | filename = filename.rstrip(". ") 53 | filename = filename.strip() 54 | 55 | if all([x == "." 
for x in filename]): 56 | filename = "_" + filename 57 | if filename in reserved: 58 | filename = "_" + filename 59 | if len(filename) == 0: 60 | filename = "_" 61 | if len(filename) > 255: 62 | parts = re.split(r"/|\\", filename)[-1].split(".") 63 | if len(parts) > 1: 64 | ext = "." + parts.pop() 65 | filename = filename[:-len(ext)] 66 | else: 67 | ext = "" 68 | if filename == "": 69 | filename = "_" 70 | if len(ext) > 254: 71 | ext = ext[254:] 72 | maxl = 255 - len(ext) 73 | filename = filename[:maxl] 74 | filename = filename + ext 75 | filename = filename.rstrip(". ") 76 | if len(filename) == 0: 77 | filename = "_" 78 | 79 | return filename 80 | 81 | 82 | def process_item(id, directory): 83 | url = ITEM_URL.format(id=id) 84 | resp = urlopen(url) 85 | url = resp.geturl() 86 | html = resp.read().decode('utf-8') 87 | 88 | if '/file/' in url: 89 | match = FILE_PATTERN.search(html) 90 | data = match.group(1).replace('\/', '/') 91 | data = data.replace(r'\x5b', '[').replace(r'\x22', '"').replace(r'\x5d', ']').replace(r'\n','') 92 | data = json.loads(data) 93 | 94 | file_name = sanitize(data[1]) 95 | file_size = int(data[25][2]) 96 | file_path = os.path.join(directory, file_name) 97 | 98 | process_file(id, file_path, file_size) 99 | elif '/folders/' in url: 100 | process_folder(id, directory, html=html) 101 | elif 'ServiceLogin' in url: 102 | sys.stderr.write('Id {} does not have link sharing enabled'.format(id)) 103 | sys.exit(1) 104 | else: 105 | sys.stderr.write('That id {} returned an unknown url'.format(id)) 106 | sys.exit(1) 107 | 108 | 109 | def process_folder(id, directory, html=None): 110 | if not html: 111 | url = FOLDER_URL.format(id=id) 112 | html = urlopen(url).read().decode('utf-8') 113 | 114 | match = FOLDER_PATTERN.search(html) 115 | data = match.group(1).replace('\/', '/') 116 | data = data.replace(r'\x5b', '[').replace(r'\x22', '"').replace(r'\x5d', ']').replace(r'\n','') 117 | data = json.loads(data) 118 | 119 | if not os.path.exists(directory): 120 
| os.mkdir(directory) 121 | output('Directory: {directory} [Created]\n'.format(directory=directory)) 122 | else: 123 | output('Directory: {directory} [Exists]\n'.format(directory=directory)) 124 | 125 | if not data[0]: 126 | return 127 | 128 | for item in sorted(data[0], key=lambda i: i[3] == FOLDER_TYPE): 129 | item_id = item[0] 130 | item_name = sanitize(item[2]) 131 | item_type = item[3] 132 | item_size = item[13] 133 | item_path = os.path.join(directory, item_name) 134 | 135 | if item_type == FOLDER_TYPE: 136 | process_folder(item_id, item_path) 137 | else: 138 | process_file(item_id, item_path, int(item_size)) 139 | 140 | 141 | def process_file(id, file_path, file_size, confirm='', cookies=''): 142 | if os.path.exists(file_path): 143 | output('{file_path} [Exists]\n'.format(file_path=file_path)) 144 | return 145 | 146 | url = FILE_URL.format(id=id, confirm=confirm) 147 | req = Request(url, headers={'Cookie': cookies, 148 | 'User-Agent': 'Mozilla/5.0'}) 149 | resp = urlopen(req) 150 | cookies = resp.headers.get('Set-Cookie') or '' 151 | 152 | if not confirm and 'download_warning' in cookies: 153 | confirm = CONFIRM_PATTERN.search(cookies) 154 | return process_file(id, file_path, file_size, confirm.group(1), cookies) 155 | 156 | output(file_path + '\n') 157 | 158 | try: 159 | with open(file_path, 'wb') as f: 160 | dl = 0 161 | while True: 162 | chunk = resp.read(4096) 163 | if not chunk: 164 | break 165 | 166 | if b'Too many users have viewed or downloaded this file recently' in chunk: 167 | raise Exception('Quota exceeded for this file') 168 | 169 | dl += len(chunk) 170 | f.write(chunk) 171 | done = int(50 * dl / file_size) 172 | output("\r[{}{}] {:.2f}MB/{:.2f}MB".format( 173 | '=' * done, 174 | ' ' * 175 | (50 - done), 176 | dl / 1024 / 1024, 177 | file_size / 1024 / 1024 178 | )) 179 | sys.stdout.flush() 180 | except: 181 | if os.path.exists(file_path): 182 | os.remove(file_path) 183 | raise 184 | 185 | output('\n') 186 | 187 | 188 | def get_arg(pos, 
default=None): 189 | try: 190 | return sys.argv[pos] 191 | except IndexError: 192 | return default 193 | 194 | 195 | if __name__ == '__main__': 196 | url = get_arg(1, '').strip() 197 | directory = get_arg(2, './').strip() 198 | id = '' 199 | 200 | if not url: 201 | sys.stderr.write('A Google Drive URL is required') 202 | sys.exit(1) 203 | 204 | for pattern in ID_PATTERNS: 205 | match = pattern.search(url) 206 | if match: 207 | id = match.group(1) 208 | break 209 | 210 | if not id: 211 | sys.stderr.write('Unable to get ID from {}'.format(url)) 212 | sys.exit(1) 213 | 214 | process_item(id, directory) 215 | -------------------------------------------------------------------------------- /USAD.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "etniX_KTlJ5U" 7 | }, 8 | "source": [ 9 | "# USAD" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "N3jM0qLU8MgZ" 16 | }, 17 | "source": [ 18 | "## Environment" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": { 25 | "id": "rjheCL2b1Rnw" 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "rm: cannot remove 'sample_data': No such file or directory\r\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "!rm -r sample_data" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/", 46 | "height": 118 47 | }, 48 | "id": "e3dDxs8LFZdT", 49 | "outputId": "ebff804d-1c59-4039-d869-f65907b19712" 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "!git clone https://github.com/manigalati/usad" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/", 62 | "height": 34 63 | }, 64 | "id": "te9stFZtFfZu", 65 | 
"outputId": "3ca36b3b-dd9a-413c-873f-ab730285ad51" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "%cd usad" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": { 76 | "id": "6u1DGKsAlLF-" 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "import numpy as np\n", 81 | "import pandas as pd\n", 82 | "import matplotlib.pyplot as plt\n", 83 | "import seaborn as sns\n", 84 | "import torch\n", 85 | "import torch.nn as nn\n", 86 | "\n", 87 | "from utils import *\n", 88 | "from usad import *" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": { 95 | "colab": { 96 | "base_uri": "https://localhost:8080/", 97 | "height": 34 98 | }, 99 | "id": "4AzWlDBI_djV", 100 | "outputId": "7a8d0c19-2389-461b-c0be-3427a25dda91" 101 | }, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "GPU 0: Quadro P6000 (UUID: GPU-e16b9553-c966-4659-d528-7376969c0e91)\n", 108 | "GPU 1: Quadro P6000 (UUID: GPU-def1ffb6-415d-a3f6-288b-94256f1ba88f)\n", 109 | "GPU 2: GeForce GTX 1080 Ti (UUID: GPU-075162a2-c2cc-7757-e07d-e1260458102e)\n", 110 | "GPU 3: GeForce GTX 1080 Ti (UUID: GPU-078c9ebd-10e3-2644-2267-bfcf3135c6a1)\n", 111 | "GPU 4: GeForce GTX 1080 Ti (UUID: GPU-db4d0970-82a3-4f24-d69f-423377f7d3c0)\n", 112 | "GPU 5: GeForce GTX 1080 Ti (UUID: GPU-945cc499-5f5f-ee9f-5b21-69a0e2e06535)\n", 113 | "GPU 6: GeForce GTX 1080 Ti (UUID: GPU-07966339-f324-1615-decf-f4825e865008)\n", 114 | "GPU 7: GeForce GTX 1080 Ti (UUID: GPU-dfa6e9f2-3de4-2115-3363-4329ff29ba3f)\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "!nvidia-smi -L\n", 120 | "\n", 121 | "device = get_default_device()" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": { 127 | "id": "1crx5rGP9ONf" 128 | }, 129 | "source": [ 130 | "## EDA - Data Pre-Processing" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": { 136 | "id": "vxofeE469RhT" 137 | }, 138 | "source": [ 
139 | "### Download dataset" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": { 146 | "colab": { 147 | "base_uri": "https://localhost:8080/", 148 | "height": 84 149 | }, 150 | "id": "i95DlAZI1G_p", 151 | "outputId": "5b35771c-356e-4e0b-a997-682d1ea85c6a", 152 | "scrolled": false 153 | }, 154 | "outputs": [ 155 | { 156 | "name": "stdout", 157 | "output_type": "stream", 158 | "text": [ 159 | "mkdir: cannot create directory 'input': File exists\n", 160 | "input/SWaT_Dataset_Normal_v1.csv [Exists]\n", 161 | "input/SWaT_Dataset_Attack_v0.csv [Exists]\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "!mkdir input\n", 167 | "#normal period\n", 168 | "!python gdrivedl.py https://drive.google.com/open?id=1rVJ5ry5GG-ZZi5yI4x9lICB8VhErXwCw input/\n", 169 | "#anomalies\n", 170 | "!python gdrivedl.py https://drive.google.com/open?id=1iDYc0OEmidN712fquOBRFjln90SbpaE7 input/" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": { 176 | "id": "kfSj4FYL9W8Y" 177 | }, 178 | "source": [ 179 | "### Normal period" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 7, 185 | "metadata": { 186 | "colab": { 187 | "base_uri": "https://localhost:8080/", 188 | "height": 87 189 | }, 190 | "id": "XeDLxV_r1G9n", 191 | "outputId": "576538dd-64f2-46fa-8e6f-6c2ffdebad15" 192 | }, 193 | "outputs": [ 194 | { 195 | "name": "stderr", 196 | "output_type": "stream", 197 | "text": [ 198 | "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3058: DtypeWarning: Columns (26) have mixed types.Specify dtype option on import or set low_memory=False.\n", 199 | " interactivity=interactivity, compiler=compiler, result=result)\n" 200 | ] 201 | }, 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "(495000, 51)" 206 | ] 207 | }, 208 | "execution_count": 7, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "#Read data\n", 215 | "normal = 
pd.read_csv(\"input/SWaT_Dataset_Normal_v1.csv\")#, nrows=1000)\n", 216 | "normal = normal.drop([\"Timestamp\" , \"Normal/Attack\" ] , axis = 1)\n", 217 | "normal.shape" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 8, 223 | "metadata": { 224 | "id": "fFuLm1GH1G2n" 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "# Transform all columns into float64\n", 229 | "for i in list(normal): \n", 230 | " normal[i]=normal[i].apply(lambda x: str(x).replace(\",\" , \".\"))\n", 231 | "normal = normal.astype(float)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": { 237 | "id": "zxFNH5kU9hIE" 238 | }, 239 | "source": [ 240 | "#### Normalization" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 9, 246 | "metadata": { 247 | "id": "Mfxj4Uxn9kv4" 248 | }, 249 | "outputs": [], 250 | "source": [ 251 | "from sklearn import preprocessing\n", 252 | "min_max_scaler = preprocessing.MinMaxScaler()\n", 253 | "\n", 254 | "x = normal.values\n", 255 | "x_scaled = min_max_scaler.fit_transform(x)\n", 256 | "normal = pd.DataFrame(x_scaled)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 10, 262 | "metadata": { 263 | "colab": { 264 | "base_uri": "https://localhost:8080/", 265 | "height": 126 266 | }, 267 | "id": "mQ6_U4jn9nlw", 268 | "outputId": "f1cc1bd6-f1cc-4764-b1cc-2fd989ac4918" 269 | }, 270 | "outputs": [ 271 | { 272 | "data": { 273 | "text/html": [ 274 | "
\n", 275 | "\n", 288 | "\n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | "
0123456789...41424344454647484950
00.00.0052940.50.00.00.0122910.0750990.0020090.00.5...0.00.00.00.0008140.00.0012170.0001470.00.00.0
10.00.0054070.50.00.00.0122910.0750990.0020090.00.5...0.00.00.00.0008140.00.0012170.0001470.00.00.0
\n", 366 | "

2 rows × 51 columns

\n", 367 | "
" 368 | ], 369 | "text/plain": [ 370 | " 0 1 2 3 4 5 6 7 8 9 ... \\\n", 371 | "0 0.0 0.005294 0.5 0.0 0.0 0.012291 0.075099 0.002009 0.0 0.5 ... \n", 372 | "1 0.0 0.005407 0.5 0.0 0.0 0.012291 0.075099 0.002009 0.0 0.5 ... \n", 373 | "\n", 374 | " 41 42 43 44 45 46 47 48 49 50 \n", 375 | "0 0.0 0.0 0.0 0.000814 0.0 0.001217 0.000147 0.0 0.0 0.0 \n", 376 | "1 0.0 0.0 0.0 0.000814 0.0 0.001217 0.000147 0.0 0.0 0.0 \n", 377 | "\n", 378 | "[2 rows x 51 columns]" 379 | ] 380 | }, 381 | "execution_count": 10, 382 | "metadata": {}, 383 | "output_type": "execute_result" 384 | } 385 | ], 386 | "source": [ 387 | "normal.head(2)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "metadata": { 393 | "id": "_i71RFAi9spa" 394 | }, 395 | "source": [ 396 | "### Attack" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 11, 402 | "metadata": { 403 | "colab": { 404 | "base_uri": "https://localhost:8080/", 405 | "height": 87 406 | }, 407 | "id": "aN_TFp5x9uTE", 408 | "outputId": "38d7993d-c9a3-461d-c430-ebde697afbc6" 409 | }, 410 | "outputs": [ 411 | { 412 | "name": "stderr", 413 | "output_type": "stream", 414 | "text": [ 415 | "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py:3058: DtypeWarning: Columns (1,9,28,46) have mixed types.Specify dtype option on import or set low_memory=False.\n", 416 | " interactivity=interactivity, compiler=compiler, result=result)\n" 417 | ] 418 | }, 419 | { 420 | "data": { 421 | "text/plain": [ 422 | "(449919, 51)" 423 | ] 424 | }, 425 | "execution_count": 11, 426 | "metadata": {}, 427 | "output_type": "execute_result" 428 | } 429 | ], 430 | "source": [ 431 | "#Read data\n", 432 | "attack = pd.read_csv(\"input/SWaT_Dataset_Attack_v0.csv\",sep=\";\")#, nrows=1000)\n", 433 | "labels = [ float(label!= 'Normal' ) for label in attack[\"Normal/Attack\"].values]\n", 434 | "attack = attack.drop([\"Timestamp\" , \"Normal/Attack\" ] , axis = 1)\n", 435 | "attack.shape" 436 | ] 437 | }, 438 | { 439 | 
"cell_type": "code", 440 | "execution_count": 12, 441 | "metadata": { 442 | "id": "qLCInT-I9_-D" 443 | }, 444 | "outputs": [], 445 | "source": [ 446 | "# Transform all columns into float64\n", 447 | "for i in list(attack):\n", 448 | " attack[i]=attack[i].apply(lambda x: str(x).replace(\",\" , \".\"))\n", 449 | "attack = attack.astype(float)" 450 | ] 451 | }, 452 | { 453 | "cell_type": "markdown", 454 | "metadata": { 455 | "id": "c4cB4v3N-Dhu" 456 | }, 457 | "source": [ 458 | "#### Normalization" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 13, 464 | "metadata": { 465 | "id": "jZrha9cO-BGK" 466 | }, 467 | "outputs": [], 468 | "source": [ 469 | "from sklearn import preprocessing\n", 470 | "\n", 471 | "x = attack.values \n", 472 | "x_scaled = min_max_scaler.transform(x)\n", 473 | "attack = pd.DataFrame(x_scaled)" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 14, 479 | "metadata": { 480 | "colab": { 481 | "base_uri": "https://localhost:8080/", 482 | "height": 126 483 | }, 484 | "id": "z9SwiPco-BUa", 485 | "outputId": "f2507282-c0f9-4253-ece7-0a802b68240f" 486 | }, 487 | "outputs": [ 488 | { 489 | "data": { 490 | "text/html": [ 491 | "
\n", 492 | "\n", 505 | "\n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | "
0123456789...41424344454647484950
00.8841440.5771331.01.00.00.4961580.1888450.0640880.9828991.0...0.9709031.00.00.9461250.4497820.9441160.0000730.00.00.0
10.8911450.5771901.01.00.00.4961580.1888450.0640880.9828991.0...0.9709031.00.00.9461250.4497820.9445210.0000730.00.00.0
\n", 583 | "

2 rows × 51 columns

\n", 584 | "
" 585 | ], 586 | "text/plain": [ 587 | " 0 1 2 3 4 5 6 7 8 \\\n", 588 | "0 0.884144 0.577133 1.0 1.0 0.0 0.496158 0.188845 0.064088 0.982899 \n", 589 | "1 0.891145 0.577190 1.0 1.0 0.0 0.496158 0.188845 0.064088 0.982899 \n", 590 | "\n", 591 | " 9 ... 41 42 43 44 45 46 47 48 \\\n", 592 | "0 1.0 ... 0.970903 1.0 0.0 0.946125 0.449782 0.944116 0.000073 0.0 \n", 593 | "1 1.0 ... 0.970903 1.0 0.0 0.946125 0.449782 0.944521 0.000073 0.0 \n", 594 | "\n", 595 | " 49 50 \n", 596 | "0 0.0 0.0 \n", 597 | "1 0.0 0.0 \n", 598 | "\n", 599 | "[2 rows x 51 columns]" 600 | ] 601 | }, 602 | "execution_count": 14, 603 | "metadata": {}, 604 | "output_type": "execute_result" 605 | } 606 | ], 607 | "source": [ 608 | "attack.head(2)" 609 | ] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": { 614 | "id": "xXJi503b-j_d" 615 | }, 616 | "source": [ 617 | "### Windows" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 15, 623 | "metadata": { 624 | "id": "vyplttZa-BRN" 625 | }, 626 | "outputs": [], 627 | "source": [ 628 | "window_size=12" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": 16, 634 | "metadata": { 635 | "colab": { 636 | "base_uri": "https://localhost:8080/", 637 | "height": 34 638 | }, 639 | "id": "dzGJMp6Y-BN5", 640 | "outputId": "2949d278-1313-442c-f06b-275a8c6c6578" 641 | }, 642 | "outputs": [ 643 | { 644 | "data": { 645 | "text/plain": [ 646 | "(494988, 12, 51)" 647 | ] 648 | }, 649 | "execution_count": 16, 650 | "metadata": {}, 651 | "output_type": "execute_result" 652 | } 653 | ], 654 | "source": [ 655 | "windows_normal=normal.values[np.arange(window_size)[None, :] + np.arange(normal.shape[0]-window_size)[:, None]]\n", 656 | "windows_normal.shape" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 17, 662 | "metadata": { 663 | "colab": { 664 | "base_uri": "https://localhost:8080/", 665 | "height": 34 666 | }, 667 | "id": "17LdB3c8-pRH", 668 | "outputId": 
"721059d4-5937-4dd3-d73c-e5d255fc273c" 669 | }, 670 | "outputs": [ 671 | { 672 | "data": { 673 | "text/plain": [ 674 | "(449907, 12, 51)" 675 | ] 676 | }, 677 | "execution_count": 17, 678 | "metadata": {}, 679 | "output_type": "execute_result" 680 | } 681 | ], 682 | "source": [ 683 | "windows_attack=attack.values[np.arange(window_size)[None, :] + np.arange(attack.shape[0]-window_size)[:, None]]\n", 684 | "windows_attack.shape" 685 | ] 686 | }, 687 | { 688 | "cell_type": "markdown", 689 | "metadata": { 690 | "id": "k70ZFxGs-_7m" 691 | }, 692 | "source": [ 693 | "## Training" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 18, 699 | "metadata": { 700 | "id": "yi9S0SGnDKNc" 701 | }, 702 | "outputs": [], 703 | "source": [ 704 | "import torch.utils.data as data_utils\n", 705 | "\n", 706 | "BATCH_SIZE = 7919\n", 707 | "N_EPOCHS = 100\n", 708 | "hidden_size = 100\n", 709 | "\n", 710 | "w_size=windows_normal.shape[1]*windows_normal.shape[2]\n", 711 | "z_size=windows_normal.shape[1]*hidden_size\n", 712 | "\n", 713 | "windows_normal_train = windows_normal[:int(np.floor(.8 * windows_normal.shape[0]))]\n", 714 | "windows_normal_val = windows_normal[int(np.floor(.8 * windows_normal.shape[0])):int(np.floor(windows_normal.shape[0]))]\n", 715 | "\n", 716 | "train_loader = torch.utils.data.DataLoader(data_utils.TensorDataset(\n", 717 | " torch.from_numpy(windows_normal_train).float().view(([windows_normal_train.shape[0],w_size]))\n", 718 | ") , batch_size=BATCH_SIZE, shuffle=False, num_workers=0)\n", 719 | "\n", 720 | "val_loader = torch.utils.data.DataLoader(data_utils.TensorDataset(\n", 721 | " torch.from_numpy(windows_normal_val).float().view(([windows_normal_val.shape[0],w_size]))\n", 722 | ") , batch_size=BATCH_SIZE, shuffle=False, num_workers=0)\n", 723 | "\n", 724 | "test_loader = torch.utils.data.DataLoader(data_utils.TensorDataset(\n", 725 | " torch.from_numpy(windows_attack).float().view(([windows_attack.shape[0],w_size]))\n", 726 | ") , 
batch_size=BATCH_SIZE, shuffle=False, num_workers=0)\n", 727 | "\n", 728 | "model = UsadModel(w_size, z_size)\n", 729 | "model = to_device(model,device)" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": 19, 735 | "metadata": { 736 | "colab": { 737 | "base_uri": "https://localhost:8080/", 738 | "height": 1000 739 | }, 740 | "id": "So9yjDPEDObC", 741 | "outputId": "629bcd13-37b1-4907-ef0d-46d9e3ad5398", 742 | "scrolled": true 743 | }, 744 | "outputs": [ 745 | { 746 | "name": "stdout", 747 | "output_type": "stream", 748 | "text": [ 749 | "Epoch [0], val_loss1: 0.0261, val_loss2: 0.0263\n", 750 | "Epoch [1], val_loss1: 0.0206, val_loss2: -0.0026\n", 751 | "Epoch [2], val_loss1: 0.0323, val_loss2: -0.0210\n", 752 | "Epoch [3], val_loss1: 0.0389, val_loss2: -0.0281\n", 753 | "Epoch [4], val_loss1: 0.0423, val_loss2: -0.0310\n", 754 | "Epoch [5], val_loss1: 0.0293, val_loss2: -0.0217\n", 755 | "Epoch [6], val_loss1: 0.0593, val_loss2: -0.0476\n", 756 | "Epoch [7], val_loss1: 0.0582, val_loss2: -0.0514\n", 757 | "Epoch [8], val_loss1: 0.0588, val_loss2: -0.0537\n", 758 | "Epoch [9], val_loss1: 0.0592, val_loss2: -0.0546\n", 759 | "Epoch [10], val_loss1: 0.0595, val_loss2: -0.0554\n", 760 | "Epoch [11], val_loss1: 0.0607, val_loss2: -0.0570\n", 761 | "Epoch [12], val_loss1: 0.0629, val_loss2: -0.0594\n", 762 | "Epoch [13], val_loss1: 0.0632, val_loss2: -0.0600\n", 763 | "Epoch [14], val_loss1: 0.1583, val_loss2: -0.1409\n", 764 | "Epoch [15], val_loss1: 0.1570, val_loss2: -0.1452\n", 765 | "Epoch [16], val_loss1: 0.1564, val_loss2: -0.1460\n", 766 | "Epoch [17], val_loss1: 0.1568, val_loss2: -0.1470\n", 767 | "Epoch [18], val_loss1: 0.1573, val_loss2: -0.1481\n", 768 | "Epoch [19], val_loss1: 0.1577, val_loss2: -0.1490\n", 769 | "Epoch [20], val_loss1: 0.1584, val_loss2: -0.1500\n", 770 | "Epoch [21], val_loss1: 0.1581, val_loss2: -0.1502\n", 771 | "Epoch [22], val_loss1: 0.1592, val_loss2: -0.1515\n", 772 | "Epoch [23], val_loss1: 0.1612, 
val_loss2: -0.1538\n", 773 | "Epoch [24], val_loss1: 0.1612, val_loss2: -0.1541\n", 774 | "Epoch [25], val_loss1: 0.1636, val_loss2: -0.1568\n", 775 | "Epoch [26], val_loss1: 0.1631, val_loss2: -0.1565\n", 776 | "Epoch [27], val_loss1: 0.1631, val_loss2: -0.1567\n", 777 | "Epoch [28], val_loss1: 0.1630, val_loss2: -0.1569\n", 778 | "Epoch [29], val_loss1: 0.1645, val_loss2: -0.1585\n", 779 | "Epoch [30], val_loss1: 0.1655, val_loss2: -0.1597\n", 780 | "Epoch [31], val_loss1: 0.1657, val_loss2: -0.1600\n", 781 | "Epoch [32], val_loss1: 0.1666, val_loss2: -0.1611\n", 782 | "Epoch [33], val_loss1: 0.1672, val_loss2: -0.1618\n", 783 | "Epoch [34], val_loss1: 0.1670, val_loss2: -0.1618\n", 784 | "Epoch [35], val_loss1: 0.1672, val_loss2: -0.1621\n", 785 | "Epoch [36], val_loss1: 0.1672, val_loss2: -0.1623\n", 786 | "Epoch [37], val_loss1: 0.1673, val_loss2: -0.1625\n", 787 | "Epoch [38], val_loss1: 0.1675, val_loss2: -0.1629\n", 788 | "Epoch [39], val_loss1: 0.1683, val_loss2: -0.1637\n", 789 | "Epoch [40], val_loss1: 0.1685, val_loss2: -0.1640\n", 790 | "Epoch [41], val_loss1: 0.1687, val_loss2: -0.1644\n", 791 | "Epoch [42], val_loss1: 0.1689, val_loss2: -0.1647\n", 792 | "Epoch [43], val_loss1: 0.1694, val_loss2: -0.1653\n", 793 | "Epoch [44], val_loss1: 0.1696, val_loss2: -0.1655\n", 794 | "Epoch [45], val_loss1: 0.1696, val_loss2: -0.1656\n", 795 | "Epoch [46], val_loss1: 0.1697, val_loss2: -0.1658\n", 796 | "Epoch [47], val_loss1: 0.1698, val_loss2: -0.1660\n", 797 | "Epoch [48], val_loss1: 0.1698, val_loss2: -0.1661\n", 798 | "Epoch [49], val_loss1: 0.1699, val_loss2: -0.1662\n", 799 | "Epoch [50], val_loss1: 0.1699, val_loss2: -0.1663\n", 800 | "Epoch [51], val_loss1: 0.1700, val_loss2: -0.1665\n", 801 | "Epoch [52], val_loss1: 0.1700, val_loss2: -0.1666\n", 802 | "Epoch [53], val_loss1: 0.1701, val_loss2: -0.1667\n", 803 | "Epoch [54], val_loss1: 0.1702, val_loss2: -0.1668\n", 804 | "Epoch [55], val_loss1: 0.1702, val_loss2: -0.1670\n", 805 | "Epoch [56], 
val_loss1: 0.1703, val_loss2: -0.1671\n", 806 | "Epoch [57], val_loss1: 0.1703, val_loss2: -0.1672\n", 807 | "Epoch [58], val_loss1: 0.1704, val_loss2: -0.1673\n", 808 | "Epoch [59], val_loss1: 0.1704, val_loss2: -0.1674\n", 809 | "Epoch [60], val_loss1: 0.1705, val_loss2: -0.1675\n", 810 | "Epoch [61], val_loss1: 0.1705, val_loss2: -0.1676\n", 811 | "Epoch [62], val_loss1: 0.1705, val_loss2: -0.1676\n", 812 | "Epoch [63], val_loss1: 0.1706, val_loss2: -0.1677\n", 813 | "Epoch [64], val_loss1: 0.1706, val_loss2: -0.1678\n", 814 | "Epoch [65], val_loss1: 0.1707, val_loss2: -0.1679\n", 815 | "Epoch [66], val_loss1: 0.1707, val_loss2: -0.1680\n", 816 | "Epoch [67], val_loss1: 0.1707, val_loss2: -0.1680\n", 817 | "Epoch [68], val_loss1: 0.1708, val_loss2: -0.1681\n", 818 | "Epoch [69], val_loss1: 0.1708, val_loss2: -0.1682\n", 819 | "Epoch [70], val_loss1: 0.1708, val_loss2: -0.1683\n", 820 | "Epoch [71], val_loss1: 0.1709, val_loss2: -0.1683\n", 821 | "Epoch [72], val_loss1: 0.1709, val_loss2: -0.1684\n", 822 | "Epoch [73], val_loss1: 0.1709, val_loss2: -0.1685\n", 823 | "Epoch [74], val_loss1: 0.1710, val_loss2: -0.1685\n", 824 | "Epoch [75], val_loss1: 0.1710, val_loss2: -0.1686\n", 825 | "Epoch [76], val_loss1: 0.1710, val_loss2: -0.1686\n", 826 | "Epoch [77], val_loss1: 0.1710, val_loss2: -0.1687\n", 827 | "Epoch [78], val_loss1: 0.1711, val_loss2: -0.1688\n", 828 | "Epoch [79], val_loss1: 0.1711, val_loss2: -0.1688\n", 829 | "Epoch [80], val_loss1: 0.1711, val_loss2: -0.1689\n", 830 | "Epoch [81], val_loss1: 0.1711, val_loss2: -0.1689\n", 831 | "Epoch [82], val_loss1: 0.1712, val_loss2: -0.1690\n", 832 | "Epoch [83], val_loss1: 0.1712, val_loss2: -0.1690\n", 833 | "Epoch [84], val_loss1: 0.1712, val_loss2: -0.1691\n", 834 | "Epoch [85], val_loss1: 0.1712, val_loss2: -0.1691\n", 835 | "Epoch [86], val_loss1: 0.1713, val_loss2: -0.1692\n", 836 | "Epoch [87], val_loss1: 0.1713, val_loss2: -0.1692\n", 837 | "Epoch [88], val_loss1: 0.1713, val_loss2: -0.1692\n", 838 | 
"Epoch [89], val_loss1: 0.1713, val_loss2: -0.1693\n", 839 | "Epoch [90], val_loss1: 0.1713, val_loss2: -0.1693\n", 840 | "Epoch [91], val_loss1: 0.1714, val_loss2: -0.1694\n", 841 | "Epoch [92], val_loss1: 0.1714, val_loss2: -0.1694\n", 842 | "Epoch [93], val_loss1: 0.1714, val_loss2: -0.1695\n", 843 | "Epoch [94], val_loss1: 0.1714, val_loss2: -0.1695\n", 844 | "Epoch [95], val_loss1: 0.1714, val_loss2: -0.1695\n", 845 | "Epoch [96], val_loss1: 0.1715, val_loss2: -0.1696\n", 846 | "Epoch [97], val_loss1: 0.1715, val_loss2: -0.1696\n", 847 | "Epoch [98], val_loss1: 0.1715, val_loss2: -0.1696\n", 848 | "Epoch [99], val_loss1: 0.1715, val_loss2: -0.1697\n" 849 | ] 850 | } 851 | ], 852 | "source": [ 853 | "history = training(N_EPOCHS,model,train_loader,val_loader)" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": 20, 859 | "metadata": { 860 | "colab": { 861 | "base_uri": "https://localhost:8080/", 862 | "height": 295 863 | }, 864 | "id": "fYwlN0JKVVtN", 865 | "outputId": "c742ff8b-3b4a-41f5-dd09-effee1be928a" 866 | }, 867 | "outputs": [ 868 | { 869 | "data": { 870 | "image/png": "\n", 871 | "text/plain": [ 872 | "
" 873 | ] 874 | }, 875 | "metadata": { 876 | "needs_background": "light" 877 | }, 878 | "output_type": "display_data" 879 | } 880 | ], 881 | "source": [ 882 | "plot_history(history)" 883 | ] 884 | }, 885 | { 886 | "cell_type": "code", 887 | "execution_count": 21, 888 | "metadata": { 889 | "id": "ieObNqKYsOzh" 890 | }, 891 | "outputs": [], 892 | "source": [ 893 | "torch.save({\n", 894 | " 'encoder': model.encoder.state_dict(),\n", 895 | " 'decoder1': model.decoder1.state_dict(),\n", 896 | " 'decoder2': model.decoder2.state_dict()\n", 897 | " }, \"model.pth\")" 898 | ] 899 | }, 900 | { 901 | "cell_type": "markdown", 902 | "metadata": { 903 | "id": "ymhjbmvR_DgJ" 904 | }, 905 | "source": [ 906 | "## Testing" 907 | ] 908 | }, 909 | { 910 | "cell_type": "code", 911 | "execution_count": 22, 912 | "metadata": { 913 | "colab": { 914 | "base_uri": "https://localhost:8080/", 915 | "height": 34 916 | }, 917 | "id": "b7rbm9wdXKeF", 918 | "outputId": "076309c7-22be-41f6-f916-5f11cb679672" 919 | }, 920 | "outputs": [ 921 | { 922 | "data": { 923 | "text/plain": [ 924 | "" 925 | ] 926 | }, 927 | "execution_count": 22, 928 | "metadata": {}, 929 | "output_type": "execute_result" 930 | } 931 | ], 932 | "source": [ 933 | "checkpoint = torch.load(\"model.pth\")\n", 934 | "\n", 935 | "model.encoder.load_state_dict(checkpoint['encoder'])\n", 936 | "model.decoder1.load_state_dict(checkpoint['decoder1'])\n", 937 | "model.decoder2.load_state_dict(checkpoint['decoder2'])" 938 | ] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "execution_count": 23, 943 | "metadata": { 944 | "id": "Ry1QTp6V2ny4" 945 | }, 946 | "outputs": [], 947 | "source": [ 948 | "results=testing(model,test_loader)" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": 24, 954 | "metadata": {}, 955 | "outputs": [], 956 | "source": [ 957 | "windows_labels=[]\n", 958 | "for i in range(len(labels)-window_size):\n", 959 | " windows_labels.append(list(np.int_(labels[i:i+window_size])))" 960 | ] 961 | }, 
962 | { 963 | "cell_type": "code", 964 | "execution_count": 25, 965 | "metadata": {}, 966 | "outputs": [], 967 | "source": [ 968 | "y_test = [1.0 if (np.sum(window) > 0) else 0 for window in windows_labels ]" 969 | ] 970 | }, 971 | { 972 | "cell_type": "code", 973 | "execution_count": 26, 974 | "metadata": { 975 | "id": "FSWwxheNvxR7" 976 | }, 977 | "outputs": [], 978 | "source": [ 979 | "y_pred=np.concatenate([torch.stack(results[:-1]).flatten().detach().cpu().numpy(),\n", 980 | " results[-1].flatten().detach().cpu().numpy()])" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": 27, 986 | "metadata": { 987 | "colab": { 988 | "base_uri": "https://localhost:8080/", 989 | "height": 279 990 | }, 991 | "id": "bROUyLM93cG3", 992 | "outputId": "755359d9-d0fb-4deb-b313-d3c2a2465a26" 993 | }, 994 | "outputs": [ 995 | { 996 | "data": { 997 | "image/png": "\n", 998 | "text/plain": [ 999 | "
" 1000 | ] 1001 | }, 1002 | "metadata": { 1003 | "needs_background": "light" 1004 | }, 1005 | "output_type": "display_data" 1006 | } 1007 | ], 1008 | "source": [ 1009 | "threshold=ROC(y_test,y_pred)" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "code", 1014 | "execution_count": null, 1015 | "metadata": {}, 1016 | "outputs": [], 1017 | "source": [] 1018 | } 1019 | ], 1020 | "metadata": { 1021 | "accelerator": "GPU", 1022 | "colab": { 1023 | "name": "USAD_test.ipynb", 1024 | "provenance": [], 1025 | "toc_visible": true 1026 | }, 1027 | "kernelspec": { 1028 | "display_name": "Python 3", 1029 | "language": "python", 1030 | "name": "python3" 1031 | }, 1032 | "language_info": { 1033 | "codemirror_mode": { 1034 | "name": "ipython", 1035 | "version": 3 1036 | }, 1037 | "file_extension": ".py", 1038 | "mimetype": "text/x-python", 1039 | "name": "python", 1040 | "nbconvert_exporter": "python", 1041 | "pygments_lexer": "ipython3", 1042 | "version": "3.6.8" 1043 | } 1044 | }, 1045 | "nbformat": 4, 1046 | "nbformat_minor": 1 1047 | } 1048 | --------------------------------------------------------------------------------