├── .gitignore ├── README.md ├── checkpoint └── eval │ ├── traffic │ ├── graph_best.pt │ └── traffic_best.pt │ └── water │ ├── GANF_water_seed_18_best.pt │ └── graph_best.pt ├── dataset.py ├── eval_water.py ├── example_baseline └── train_SVDD_water.py ├── log └── water.log ├── models ├── DROCC.py ├── DeepSAD.py ├── GAN.py ├── GANF.py ├── NF.py ├── RNN.py ├── __pycache__ │ ├── DROCC.cpython-38.pyc │ ├── DeepSAD.cpython-36.pyc │ ├── DeepSAD.cpython-38.pyc │ ├── GAN.cpython-38.pyc │ ├── GANF.cpython-38.pyc │ ├── GDN.cpython-36.pyc │ ├── NF.cpython-36.pyc │ ├── NF.cpython-38.pyc │ ├── PMUNF.cpython-36.pyc │ ├── PMUNF.cpython-38.pyc │ ├── RNN.cpython-36.pyc │ ├── RNN.cpython-38.pyc │ └── graph_layer.cpython-36.pyc └── graph_layer.py ├── train_traffic.py ├── train_water.py ├── train_water.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /checkpoint/model 2 | /data/* 3 | **pycache** -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # GANF 3 | Official implementation of "Graph-Augmented Normalizing Flows for Anomaly Detection of Multiple Time Series" (ICLR 2022). [[paper]](https://openreview.net/pdf?id=45L_dgP48Vd) 4 | 5 | ## Requirements 6 | 7 | ``` 8 | torch==1.7.1 9 | ``` 10 | 11 | ## Overview 12 | 13 | * `./models`: This directory includes the code of GANF as well as the baseline methods. 14 | * `./checkpoint`: This directory stores the trained models. The trained models for the datasets **SWaT** and **Metr-LA** are given in `./checkpoint/eval`. 15 | * `./train_water.py` and `./train_traffic.py`: These programs are used to train GANF on the corresponding datasets. 16 | * `./data`: This directory is used to store the datasets. 17 | 18 | 19 | ## Datasets 20 | The paper uses three datasets for experiments: 21 | * **SWaT**: This water system dataset can be requested from [iTrust](https://itrust.sutd.edu.sg/). We utilize the attack_v0 data from Dec/2015 for experimentation. You may need to convert the file to .csv format first to use our code. Then, use `./dataset.py` to perform the train/val/test split. 22 | * **Metr-LA**: This traffic dataset does not include ground-truth outliers. It can be used for exploratory studies of density estimation. The dataset can be downloaded from [this GitHub](https://github.com/liyaguang/DCRNN). 23 | * **PMU**: This power grid dataset is proprietary and we are unable to offer it for public use. 24 | 25 | ## Experiments 26 | To train a GANF model on **SWaT**, run the bash script: 27 | ``` 28 | bash train_water.sh 29 | ``` 30 | The training log will be saved to `./log`; a reference log is provided there to help reproduce the results in the paper. 31 | 32 | We also provide trained models in `./checkpoint/eval` for evaluation. You can call: 33 | ``` 34 | python eval_water.py 35 | ``` 36 | 37 | To train a GANF model on **Metr-LA**, run: 38 | ``` 39 | python train_traffic.py 40 | ``` 41 | 42 | ## Citation 43 | If you find this repo useful, please cite the paper. Thank you!
44 | ``` 45 | @inproceedings{ 46 | dai2022graphaugmented, 47 | title={Graph-Augmented Normalizing Flows for Anomaly Detection of Multiple Time Series}, 48 | author={Enyan Dai and Jie Chen}, 49 | booktitle={International Conference on Learning Representations}, 50 | year={2022}, 51 | url={https://openreview.net/forum?id=45L_dgP48Vd} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /checkpoint/eval/traffic/graph_best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/checkpoint/eval/traffic/graph_best.pt -------------------------------------------------------------------------------- /checkpoint/eval/traffic/traffic_best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/checkpoint/eval/traffic/traffic_best.pt -------------------------------------------------------------------------------- /checkpoint/eval/water/GANF_water_seed_18_best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/checkpoint/eval/water/GANF_water_seed_18_best.pt -------------------------------------------------------------------------------- /checkpoint/eval/water/graph_best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/checkpoint/eval/water/graph_best.pt -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import pandas as pd 3 | import torch 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | 7 | # %% 8 | from torch.utils.data import DataLoader 9 | def load_traffic(root, batch_size): 10 | """ 11 | Load traffic dataset 12 | return train_loader, val_loader, test_loader 13 | """ 14 | df = pd.read_hdf(root) 15 | df = df.reset_index() 16 | df = df.rename(columns={"index":"utc"}) 17 | df["utc"] = pd.to_datetime(df["utc"], unit="s") 18 | df = df.set_index("utc") 19 | n_sensor = len(df.columns) 20 | 21 | mean = df.values.flatten().mean() 22 | std = df.values.flatten().std() 23 | 24 | df = (df - mean)/std 25 | df = df.sort_index() 26 | # split the dataset 27 | train_df = df.iloc[:int(0.75*len(df))] 28 | val_df = df.iloc[int(0.75*len(df)):int(0.875*len(df))] 29 | test_df = df.iloc[int(0.75*len(df)):] 30 | 31 | train_loader = DataLoader(Traffic(train_df), batch_size=batch_size, shuffle=True) 32 | val_loader = DataLoader(Traffic(val_df), batch_size=batch_size, shuffle=False) 33 | test_loader = DataLoader(Traffic(test_df), batch_size=batch_size, shuffle=False) 34 | 35 | return train_loader, val_loader, test_loader, n_sensor 36 | 37 | class Traffic(Dataset): 38 | def __init__(self, df, window_size=12, stride_size=1): 39 | super(Traffic, self).__init__() 40 | self.df = df 41 | self.window_size = window_size 42 | self.stride_size = stride_size 43 | 44 | self.data, self.idx, self.time = self.preprocess(df) 45 | 46 | def preprocess(self, df): 47 | 48 | start_idx = np.arange(0,len(df)-self.window_size,self.stride_size) 49 | end_idx = np.arange(self.window_size, len(df), self.stride_size) 50 | 51 | delat_time = df.index[end_idx]-df.index[start_idx] 52 
| idx_mask = delat_time==pd.Timedelta(5*self.window_size,unit='min') 53 | 54 | return df.values, start_idx[idx_mask], df.index[start_idx[idx_mask]] 55 | 56 | def __len__(self): 57 | 58 | length = len(self.idx) 59 | 60 | return length 61 | 62 | def __getitem__(self, index): 63 | # N X K X L X D 64 | start = self.idx[index] 65 | end = start + self.window_size 66 | data = self.data[start:end].reshape([self.window_size,-1, 1]) 67 | 68 | return torch.FloatTensor(data).transpose(0,1) 69 | 70 | def load_water(root, batch_size,label=False): 71 | 72 | data = pd.read_csv(root) 73 | data = data.rename(columns={"Normal/Attack":"label"}) 74 | data.label[data.label!="Normal"]=1 75 | data.label[data.label=="Normal"]=0 76 | data["Timestamp"] = pd.to_datetime(data["Timestamp"]) 77 | data = data.set_index("Timestamp") 78 | 79 | #%% 80 | feature = data.iloc[:,:51] 81 | mean_df = feature.mean(axis=0) 82 | std_df = feature.std(axis=0) 83 | 84 | norm_feature = (feature-mean_df)/std_df 85 | norm_feature = norm_feature.dropna(axis=1) 86 | n_sensor = len(norm_feature.columns) 87 | 88 | train_df = norm_feature.iloc[:int(0.6*len(data))] 89 | train_label = data.label.iloc[:int(0.6*len(data))] 90 | 91 | val_df = norm_feature.iloc[int(0.6*len(data)):int(0.8*len(data))] 92 | val_label = data.label.iloc[int(0.6*len(data)):int(0.8*len(data))] 93 | 94 | test_df = norm_feature.iloc[int(0.8*len(data)):] 95 | test_label = data.label.iloc[int(0.8*len(data)):] 96 | if label: 97 | train_loader = DataLoader(WaterLabel(train_df,train_label), batch_size=batch_size, shuffle=True) 98 | else: 99 | train_loader = DataLoader(Water(train_df,train_label), batch_size=batch_size, shuffle=True) 100 | val_loader = DataLoader(Water(val_df,val_label), batch_size=batch_size, shuffle=False) 101 | test_loader = DataLoader(Water(test_df,test_label), batch_size=batch_size, shuffle=False) 102 | 103 | return train_loader, val_loader, test_loader, n_sensor 104 | 105 | class Water(Dataset): 106 | def __init__(self, df, label, window_size=60, stride_size=10): 107 | super(Water, self).__init__() 108 | self.df = df 109 | self.window_size = window_size 110 | self.stride_size = stride_size 111 | 112 | self.data, self.idx, self.label = self.preprocess(df,label) 113 | 114 | def preprocess(self, df, label): 115 | 116 | start_idx = np.arange(0,len(df)-self.window_size,self.stride_size) 117 | end_idx = np.arange(self.window_size, len(df), self.stride_size) 118 | 119 | delat_time = df.index[end_idx]-df.index[start_idx] 120 | idx_mask = delat_time==pd.Timedelta(self.window_size,unit='s') 121 | 122 | return df.values, start_idx[idx_mask], label[start_idx[idx_mask]] 123 | 124 | def __len__(self): 125 | 126 | length = len(self.idx) 127 | 128 | return length 129 | 130 | def __getitem__(self, index): 131 | # N X K X L X D 132 | start = self.idx[index] 133 | end = start + self.window_size 134 | data = self.data[start:end].reshape([self.window_size,-1, 1]) 135 | 136 | return torch.FloatTensor(data).transpose(0,1) 137 | 138 | 139 | class WaterLabel(Dataset): 140 | def __init__(self, df, label, window_size=60, stride_size=10): 141 | super(WaterLabel, self).__init__() 142 | self.df = df 143 | self.window_size = window_size 144 | self.stride_size = stride_size 145 | 146 | self.data, self.idx, self.label = self.preprocess(df,label) 147 | self.label = 1.0-2*self.label 148 | 149 | def preprocess(self, df, label): 150 | 151 | start_idx = np.arange(0,len(df)-self.window_size,self.stride_size) 152 | end_idx = np.arange(self.window_size, len(df), self.stride_size) 153 | 154 | 
delat_time = df.index[end_idx]-df.index[start_idx] 155 | idx_mask = delat_time==pd.Timedelta(self.window_size,unit='s') 156 | 157 | return df.values, start_idx[idx_mask], label[start_idx[idx_mask]] 158 | 159 | def __len__(self): 160 | 161 | length = len(self.idx) 162 | 163 | return length 164 | 165 | def __getitem__(self, index): 166 | # N X K X L X D 167 | start = self.idx[index] 168 | end = start + self.window_size 169 | data = self.data[start:end].reshape([self.window_size,-1, 1]) 170 | 171 | return torch.FloatTensor(data).transpose(0,1),self.label[index] -------------------------------------------------------------------------------- /eval_water.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | import argparse 4 | import torch 5 | from models.GANF import GANF 6 | import numpy as np 7 | from sklearn.metrics import roc_auc_score 8 | # from data import fetch_dataloaders 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | # files 13 | parser.add_argument('--data_dir', type=str, 14 | default='./data/SWaT_Dataset_Attack_v0.csv', help='Location of datasets.') 15 | parser.add_argument('--output_dir', type=str, 16 | default='/home/enyandai/code/checkpoint/model') 17 | parser.add_argument('--name',default='GANF_Water') 18 | # restore 19 | parser.add_argument('--graph', type=str, default='None') 20 | parser.add_argument('--model', type=str, default='None') 21 | parser.add_argument('--seed', type=int, default=10, help='Random seed to use.') 22 | # made parameters 23 | parser.add_argument('--n_blocks', type=int, default=1, help='Number of blocks to stack in a model (MADE in MAF; Coupling+BN in RealNVP).') 24 | parser.add_argument('--n_components', type=int, default=1, help='Number of Gaussian clusters for mixture of gaussians models.') 25 | parser.add_argument('--hidden_size', type=int, default=32, help='Hidden layer size for MADE (and each MADE block in an MAF).') 26 | parser.add_argument('--n_hidden', type=int, default=1, help='Number of hidden layers in each MADE.') 27 | parser.add_argument('--batch_norm', type=bool, default=False) 28 | # training params 29 | parser.add_argument('--batch_size', type=int, default=512) 30 | 31 | args = parser.parse_known_args()[0] 32 | args.cuda = torch.cuda.is_available() 33 | device = torch.device("cuda" if args.cuda else "cpu") 34 | 35 | 36 | print(args) 37 | import random 38 | import numpy as np 39 | random.seed(args.seed) 40 | np.random.seed(args.seed) 41 | torch.manual_seed(args.seed) 42 | if args.cuda: 43 | torch.cuda.manual_seed(args.seed) 44 | #%% 45 | print("Loading dataset") 46 | from dataset import load_water 47 | 48 | train_loader, val_loader, test_loader, n_sensor = load_water(args.data_dir, \ 49 | args.batch_size) 50 | 51 | #%% 52 | model = GANF(args.n_blocks, 1, args.hidden_size, args.n_hidden, dropout=0.0, batch_norm=args.batch_norm) 53 | model = model.to(device) 54 | 55 | 56 | model.load_state_dict(torch.load("./checkpoint/eval/water/GANF_water_seed_18_best.pt")) 57 | A = torch.load("./checkpoint/eval/water/graph_best.pt").to(device) 58 | model.eval() 59 | #%% 60 | loss_test = [] 61 | with torch.no_grad(): 62 | for x in test_loader: 63 | 64 | x = x.to(device) 65 | loss = -model.test(x, A.data).cpu().numpy() 66 | loss_test.append(loss) 67 | loss_test = np.concatenate(loss_test) 68 | roc_test = roc_auc_score(np.asarray(test_loader.dataset.label.values,dtype=int),loss_test) 69 | print("The ROC-AUC score on the SWaT dataset is {}".format(roc_test)) 70 | # %% 71 |
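The script above reports window-level ROC-AUC only. If a hard anomaly decision is also needed, one possible follow-up (not part of the original code) is to score the validation split with the same trained model, choose a score threshold there, and then report window-level precision/recall/F1 on the test split. The sketch below assumes the variables already defined in `eval_water.py` (`model`, `A`, `device`, `val_loader`, `test_loader`, `loss_test`); the helpers `score_loader` and `point_metrics` are hypothetical names introduced here.
```
import numpy as np
import torch

def score_loader(loader):
    # Negative mean log-likelihood per window, matching the scoring used above.
    scores = []
    with torch.no_grad():
        for x in loader:
            x = x.to(device)
            scores.append(-model.test(x, A.data).cpu().numpy())
    return np.concatenate(scores)

def point_metrics(scores, labels, thr):
    # Window-level precision / recall / F1 for a fixed score threshold.
    pred = (scores > thr).astype(int)
    tp = int(((pred == 1) & (labels == 1)).sum())
    fp = int(((pred == 1) & (labels == 0)).sum())
    fn = int(((pred == 0) & (labels == 1)).sum())
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

loss_val = score_loader(val_loader)
val_label = np.asarray(val_loader.dataset.label.values, dtype=int)
test_label = np.asarray(test_loader.dataset.label.values, dtype=int)

# Sweep thresholds over the upper percentiles of the validation scores and keep the best F1.
candidates = np.percentile(loss_val, np.linspace(80.0, 99.5, 40))
best_thr = max(candidates, key=lambda thr: point_metrics(loss_val, val_label, thr)[2])

p, r, f1 = point_metrics(loss_test, test_label, best_thr)
print("threshold {:.4f}: precision {:.4f}, recall {:.4f}, F1 {:.4f}".format(best_thr, p, r, f1))
```
Because the SWaT attacks occur as contiguous segments, a percentile sweep on validation scores is only a rough heuristic; point-adjusted or segment-based metrics are common alternatives in the anomaly-detection literature.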
-------------------------------------------------------------------------------- /example_baseline/train_SVDD_water.py: -------------------------------------------------------------------------------- 1 | 2 | #%% 3 | import os 4 | import argparse 5 | import torch 6 | from models.RNN import RecurrentAE 7 | import torch.nn.functional as F 8 | from dataset import PMUTime 9 | import numpy as np 10 | 11 | parser = argparse.ArgumentParser() 12 | # action 13 | 14 | parser.add_argument('--data_dir', type=str, 15 | default='/data', help='Location of datasets.') 16 | parser.add_argument('--output_dir', type=str, 17 | default='/home/enyandai/code/checkpoint/model') 18 | parser.add_argument('--dataset', type=str, default='C') 19 | parser.add_argument('--model', type=str, 20 | default='None') 21 | parser.add_argument('--name',default='SVDD_Water_test') 22 | parser.add_argument('--seed', type=int, default=11, help='Random seed to use.') 23 | # made parameters 24 | parser.add_argument('--hidden_size', type=int, default=64, help='Hidden layer size for MADE (and each MADE block in an MAF).') 25 | parser.add_argument('--n_hidden', type=int, default=1, help='Number of hidden layers in each MADE.') 26 | # training params 27 | parser.add_argument('--batch_size', type=int, default=1024) 28 | parser.add_argument('--weight_decay', type=float, default=5e-4) 29 | parser.add_argument('--n_epochs', type=int, default=50) 30 | parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.') 31 | 32 | 33 | args = parser.parse_known_args()[0] 34 | args.cuda = torch.cuda.is_available() 35 | device = torch.device("cuda" if args.cuda else "cpu") 36 | 37 | print(args) 38 | import random 39 | import numpy as np 40 | random.seed(args.seed) 41 | np.random.seed(args.seed) 42 | torch.manual_seed(args.seed) 43 | if args.cuda: 44 | torch.cuda.manual_seed(args.seed) 45 | #%% load dataset 46 | print("Loading dataset") 47 | from dataset import load_water 48 | 49 | train_loader, val_loader, test_loader, n_sensor = load_water("/home/enyandai/orginal_code/data/SWaT_Dataset_Attack_v0.csv", \ 50 | args.batch_size) 51 | 52 | # %% 53 | from models.DeepSAD import DeepSVDD 54 | model = DeepSVDD(n_sensor, args.hidden_size, device) 55 | #%% 56 | model.pretrain(train_loader, args, device) 57 | model.train(train_loader, args, device) 58 | #%% 59 | save_path = os.path.join(args.output_dir,args.name) 60 | if not os.path.exists(save_path): 61 | os.makedirs(save_path) 62 | model.save_model(os.path.join(save_path, "{}.pt".format(args.name))) 63 | #%% 64 | # for seed in range(10,21): 65 | # model.load_model("/home/enyandai/orginal_code/checkpoint/model/SVDD_Water/SVDD_Water_39.pt") 66 | model.net.eval() 67 | loss = [] 68 | from sklearn.metrics import roc_auc_score 69 | with torch.no_grad(): 70 | for data in test_loader: 71 | 72 | x = data.to(device) 73 | x = torch.transpose(x, dim0=2, dim1=3) 74 | inputs = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 75 | outputs = model.net(inputs).squeeze().mean(dim=-1) 76 | batch_loss= torch.sum((outputs - model.c) ** 2, dim=1).cpu().numpy() 77 | loss.append(batch_loss) 78 | loss = np.concatenate(loss) 79 | roc_test = roc_auc_score(np.asarray(test_loader.dataset.label.values,dtype=int),loss) 80 | print("ROC: {:.4f}".format(roc_test)) 81 | 82 | # %% 83 | 84 | # %% 85 | -------------------------------------------------------------------------------- /log/water.log: -------------------------------------------------------------------------------- 1 | Namespace(alpha_init=0.0, batch_norm=False, 
batch_size=512, cuda=True, data_dir='./data/SWaT_Dataset_Attack_v0.csv', graph='None', h_tol=0.0001, hidden_size=32, lambda1=0.0, log_interval=5, lr=0.002, max_iter=20, model='None', n_blocks=1, n_components=1, n_epochs=1, n_hidden=1, name='GANF_Water', output_dir='/home/enyandai/code/checkpoint/model', rho_init=1.0, rho_max=1e+16, seed=18, weight_decay=0.0005) 2 | Loading dataset 3 | -1.1010821 -2.468518 -1.1290152 -2.4325068 4 | Epoch: 1, train -log_prob: -0.25, test -log_prob: -2.31, roc_val: 0.8767, roc_test: 0.7705 ,h: 0.16984939575195312 5 | rho: 1.0, alpha 0.0, h 0.16984939575195312 6 | =========================================== 7 | -1.4870193 -2.9885864 -1.5630229 -2.9028065 8 | Epoch: 2, train -log_prob: -0.82, test -log_prob: -2.75, roc_val: 0.8833, roc_test: 0.7652 ,h: 0.1089019775390625 9 | rho: 1.0, alpha 0.16984939575195312, h 0.1089019775390625 10 | =========================================== 11 | -1.5830858 -2.9737263 -1.5540357 -2.92044 12 | Epoch: 3, train -log_prob: 2.17, test -log_prob: -2.76, roc_val: 0.8834, roc_test: 0.7689 ,h: 0.06237030029296875 13 | rho: 10.0, alpha 0.16984939575195312, h 0.06237030029296875 14 | =========================================== 15 | -2.048981 -3.4452784 -2.0425334 -3.3688452 16 | Epoch: 4, train -log_prob: 0.02, test -log_prob: -3.21, roc_val: 0.8634, roc_test: 0.7636 ,h: 0.03321075439453125 17 | rho: 10.0, alpha 0.7935523986816406, h 0.03321075439453125 18 | =========================================== 19 | -1.3401155 -2.4521708 -1.3767326 -2.460982 20 | Epoch: 5, train -log_prob: 2.47, test -log_prob: -2.28, roc_val: 0.7656, roc_test: 0.7363 ,h: 0.0155487060546875 21 | rho: 100.0, alpha 0.7935523986816406, h 0.0155487060546875 22 | =========================================== 23 | -2.1675432 -3.3567538 -2.1354437 -3.3120098 24 | Epoch: 6, train -log_prob: 0.12, test -log_prob: -3.17, roc_val: 0.8741, roc_test: 0.7651 ,h: 0.0060272216796875 25 | rho: 100.0, alpha 2.3484230041503906, h 0.0060272216796875 26 | =========================================== 27 | -2.0511203 -3.1093624 -2.0011985 -3.0835102 28 | Epoch: 7, train -log_prob: 0.98, test -log_prob: -2.96, roc_val: 0.8751, roc_test: 0.7656 ,h: 0.002109527587890625 29 | rho: 100.0, alpha 2.9511451721191406, h 0.002109527587890625 30 | =========================================== 31 | -2.174802 -3.165614 -2.1099446 -3.150142 32 | Epoch: 8, train -log_prob: -0.07, test -log_prob: -3.03, roc_val: 0.8719, roc_test: 0.7632 ,h: 0.000568389892578125 33 | rho: 100.0, alpha 3.162097930908203, h 0.000568389892578125 34 | =========================================== 35 | -1.418484 -2.242597 -1.40083 -2.2324018 36 | Epoch: 9, train -log_prob: 0.40, test -log_prob: -2.15, roc_val: 0.8671, roc_test: 0.8183 ,h: 0.00011444091796875 37 | rho: 100.0, alpha 3.2189369201660156, h 0.00011444091796875 38 | =========================================== 39 | -2.4398491 -3.548349 -2.4290686 -3.5319073 40 | Epoch: 10, train -log_prob: -1.07, test -log_prob: -3.40, roc_val: 0.8552, roc_test: 0.7519 ,h: 0.0001220703125 41 | rho: 100.0, alpha 3.2303810119628906, h 0.0001220703125 42 | =========================================== 43 | -2.0225897 -2.9151032 -1.9856229 -2.8960047 44 | Epoch: 11, train -log_prob: 2.36, test -log_prob: -2.79, roc_val: 0.8721, roc_test: 0.7484 ,h: 0.000942230224609375 45 | rho: 1000.0, alpha 3.2303810119628906, h 0.000942230224609375 46 | =========================================== 47 | -1.9903402 -2.881948 -1.9751995 -2.86814 48 | Epoch: 12, train -log_prob: -0.75, test -log_prob: -2.76, 
roc_val: 0.8644, roc_test: 0.7447 ,h: 0.00052642822265625 49 | rho: 10000.0, alpha 3.2303810119628906, h 0.00052642822265625 50 | =========================================== 51 | -2.260918 -3.279379 -2.2225397 -3.2691169 52 | Epoch: 13, train -log_prob: -1.01, test -log_prob: -3.17, roc_val: 0.8584, roc_test: 0.8050 ,h: 9.918212890625e-05 53 | rho: 100000.0, alpha 3.2303810119628906, h 9.918212890625e-05 54 | =========================================== 55 | -2.252092 -3.2482724 -2.2047806 -3.2346206 56 | Epoch: 14, train -log_prob: 0.87, test -log_prob: -3.12, roc_val: 0.8778, roc_test: 0.7811 ,h: 4.57763671875e-05 57 | rho: 1000000.0, alpha 3.2303810119628906, h 4.57763671875e-05 58 | =========================================== 59 | Epoch: 15, train -log_prob: 0.46, test -log_prob: -3.02, roc_val: 0.8693, roc_test: 0.7725 ,h: 4.1961669921875e-05 60 | save model 15 epoch 61 | Epoch: 16, train -log_prob: -3.40, test -log_prob: -3.90, roc_val: 0.8608, roc_test: 0.7395 ,h: 2.6702880859375e-05 62 | save model 16 epoch 63 | Epoch: 17, train -log_prob: -3.68, test -log_prob: -3.93, roc_val: 0.7179, roc_test: 0.7496 ,h: 1.52587890625e-05 64 | save model 17 epoch 65 | Epoch: 18, train -log_prob: -3.73, test -log_prob: -3.97, roc_val: 0.8399, roc_test: 0.7550 ,h: 1.1444091796875e-05 66 | save model 18 epoch 67 | Epoch: 19, train -log_prob: -3.76, test -log_prob: -3.98, roc_val: 0.7907, roc_test: 0.7594 ,h: 1.1444091796875e-05 68 | save model 19 epoch 69 | Epoch: 20, train -log_prob: -3.77, test -log_prob: -4.00, roc_val: 0.8407, roc_test: 0.7589 ,h: 1.1444091796875e-05 70 | save model 20 epoch 71 | Epoch: 21, train -log_prob: -3.79, test -log_prob: -3.99, roc_val: 0.8352, roc_test: 0.7599 ,h: 7.62939453125e-06 72 | Epoch: 22, train -log_prob: -3.80, test -log_prob: -4.02, roc_val: 0.8365, roc_test: 0.7732 ,h: 7.62939453125e-06 73 | save model 22 epoch 74 | Epoch: 23, train -log_prob: -3.81, test -log_prob: -4.00, roc_val: 0.7389, roc_test: 0.7651 ,h: 7.62939453125e-06 75 | Epoch: 24, train -log_prob: -3.82, test -log_prob: -4.03, roc_val: 0.8376, roc_test: 0.7699 ,h: 7.62939453125e-06 76 | save model 24 epoch 77 | Epoch: 25, train -log_prob: -3.83, test -log_prob: -4.01, roc_val: 0.8032, roc_test: 0.7626 ,h: 7.62939453125e-06 78 | Epoch: 26, train -log_prob: -3.83, test -log_prob: -4.05, roc_val: 0.8129, roc_test: 0.7785 ,h: 7.62939453125e-06 79 | save model 26 epoch 80 | Epoch: 27, train -log_prob: -3.84, test -log_prob: -4.03, roc_val: 0.7634, roc_test: 0.7720 ,h: 7.62939453125e-06 81 | Epoch: 28, train -log_prob: -3.85, test -log_prob: -4.05, roc_val: 0.7682, roc_test: 0.7778 ,h: 7.62939453125e-06 82 | Epoch: 29, train -log_prob: -3.85, test -log_prob: -4.03, roc_val: 0.7887, roc_test: 0.7805 ,h: 7.62939453125e-06 83 | Epoch: 30, train -log_prob: -3.86, test -log_prob: -4.07, roc_val: 0.8116, roc_test: 0.7847 ,h: 7.62939453125e-06 84 | save model 30 epoch 85 | Epoch: 31, train -log_prob: -3.86, test -log_prob: -4.04, roc_val: 0.7788, roc_test: 0.7865 ,h: 7.62939453125e-06 86 | Epoch: 32, train -log_prob: -3.87, test -log_prob: -4.06, roc_val: 0.7928, roc_test: 0.7918 ,h: 7.62939453125e-06 87 | Epoch: 33, train -log_prob: -3.87, test -log_prob: -4.02, roc_val: 0.7828, roc_test: 0.7851 ,h: 7.62939453125e-06 88 | Epoch: 34, train -log_prob: -3.88, test -log_prob: -4.03, roc_val: 0.7377, roc_test: 0.7924 ,h: 7.62939453125e-06 89 | Epoch: 35, train -log_prob: -3.88, test -log_prob: -4.05, roc_val: 0.7093, roc_test: 0.7917 ,h: 7.62939453125e-06 90 | Epoch: 36, train -log_prob: -3.89, test -log_prob: 
-4.05, roc_val: 0.7840, roc_test: 0.8006 ,h: 7.62939453125e-06 91 | Epoch: 37, train -log_prob: -3.89, test -log_prob: -4.07, roc_val: 0.7289, roc_test: 0.7894 ,h: 7.62939453125e-06 92 | Epoch: 38, train -log_prob: -3.90, test -log_prob: -4.05, roc_val: 0.8108, roc_test: 0.7993 ,h: 7.62939453125e-06 93 | Epoch: 39, train -log_prob: -3.90, test -log_prob: -4.06, roc_val: 0.7141, roc_test: 0.7874 ,h: 7.62939453125e-06 94 | Epoch: 40, train -log_prob: -3.90, test -log_prob: -4.08, roc_val: 0.7490, roc_test: 0.8101 ,h: 7.62939453125e-06 95 | save model 40 epoch 96 | Epoch: 41, train -log_prob: -3.90, test -log_prob: -4.05, roc_val: 0.7221, roc_test: 0.7938 ,h: 7.62939453125e-06 97 | Epoch: 42, train -log_prob: -3.91, test -log_prob: -4.07, roc_val: 0.7360, roc_test: 0.8060 ,h: 7.62939453125e-06 98 | Epoch: 43, train -log_prob: -3.91, test -log_prob: -4.00, roc_val: 0.8063, roc_test: 0.7809 ,h: 7.62939453125e-06 99 | Epoch: 44, train -log_prob: -3.92, test -log_prob: -4.06, roc_val: 0.7648, roc_test: 0.8081 ,h: 7.62939453125e-06 100 | -------------------------------------------------------------------------------- /models/DROCC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class LSTM_FC(nn.Module): 8 | def __init__(self, 9 | input_dim=32, 10 | num_classes=1, 11 | num_hidden_nodes=8 12 | ): 13 | 14 | super(LSTM_FC, self).__init__() 15 | self.input_dim = input_dim 16 | self.num_classes = num_classes 17 | self.num_hidden_nodes = num_hidden_nodes 18 | self.encoder = nn.LSTM(input_size=self.input_dim, hidden_size=self.num_hidden_nodes, 19 | num_layers=1, batch_first=True) 20 | self.fc = nn.Linear(self.num_hidden_nodes, self.num_classes) 21 | activ = nn.ReLU(True) 22 | 23 | def forward(self, input): 24 | features = self.encoder(input)[0][:,-1,:] 25 | # pdb.set_trace() 26 | logits = self.fc(features) 27 | return logits 28 | 29 | def half_forward_start(self, x): 30 | features = self.encoder(x)[0][:,-1,:] 31 | return features 32 | 33 | def half_forward_end(self, x): 34 | logits = self.fc(x) 35 | return logits 36 | 37 | class DROCCTrainer: 38 | """ 39 | Trainer class that implements the DROCC algorithm proposed in 40 | https://arxiv.org/abs/2002.12718 41 | """ 42 | 43 | def __init__(self, model, optimizer, lamda, radius, gamma, device): 44 | """Initialize the DROCC Trainer class 45 | Parameters 46 | ---------- 47 | model: Torch neural network object 48 | optimizer: Total number of epochs for training. 49 | lamda: Weight given to the adversarial loss 50 | radius: Radius of hypersphere to sample points from. 51 | gamma: Parameter to vary projection. 52 | device: torch.device object for device to use. 53 | """ 54 | self.model = model 55 | self.optimizer = optimizer 56 | self.lamda = lamda 57 | self.radius = radius 58 | self.gamma = gamma 59 | self.device = device 60 | 61 | def train(self, train_loader, lr_scheduler, total_epochs, save_path, name, 62 | only_ce_epochs=5, ascent_step_size=0.001, ascent_num_steps=50): 63 | """Trains the model on the given training dataset with periodic 64 | evaluation on the validation dataset. 65 | Parameters 66 | ---------- 67 | train_loader: Dataloader object for the training dataset. 68 | val_loader: Dataloader object for the validation dataset. 69 | learning_rate: Initial learning rate for training. 70 | total_epochs: Total number of epochs for training. 
71 | only_ce_epochs: Number of epochs for initial pretraining. 72 | ascent_step_size: Step size for gradient ascent for adversarial 73 | generation of negative points. 74 | ascent_num_steps: Number of gradient ascent steps for adversarial 75 | generation of negative points. 76 | metric: Metric used for evaluation (AUC / F1). 77 | """ 78 | self.ascent_num_steps = ascent_num_steps 79 | self.ascent_step_size = ascent_step_size 80 | for epoch in range(total_epochs): 81 | #Make the weights trainable 82 | self.model.train() 83 | 84 | #Placeholder for the respective 2 loss values 85 | epoch_adv_loss = 0.0 #AdvLoss 86 | epoch_ce_loss = 0.0 #Cross entropy Loss 87 | 88 | batch_idx = -1 89 | for data in train_loader: 90 | batch_idx += 1 91 | data = data.to(self.device) 92 | target = torch.ones([data.shape[0]], dtype=torch.float32).to(self.device) 93 | 94 | data = torch.transpose(data, dim0=1, dim1=2) 95 | data = data.reshape(data.shape[0], data.shape[1], data.shape[2]*data.shape[3]) 96 | 97 | self.optimizer.zero_grad() 98 | 99 | # Extract the logits for cross entropy loss 100 | logits = self.model(data) 101 | logits = torch.squeeze(logits, dim = 1) 102 | ce_loss = F.binary_cross_entropy_with_logits(logits, target) 103 | # Add to the epoch variable for printing average CE Loss 104 | epoch_ce_loss += ce_loss.item() 105 | 106 | ''' 107 | Adversarial Loss is calculated only for the positive data points (label==1). 108 | ''' 109 | if epoch >= only_ce_epochs: 110 | data = data[target == 1] 111 | # AdvLoss 112 | adv_loss = self.one_class_adv_loss(data) 113 | epoch_adv_loss += adv_loss.item() 114 | 115 | loss = ce_loss + adv_loss * self.lamda 116 | else: 117 | # If only CE based training has to be done 118 | loss = ce_loss 119 | 120 | # Backprop 121 | loss.backward() 122 | self.optimizer.step() 123 | lr_scheduler.step() 124 | 125 | epoch_ce_loss = epoch_ce_loss/(batch_idx + 1) #Average CE Loss 126 | epoch_adv_loss = epoch_adv_loss/(batch_idx + 1) #Average AdvLoss 127 | 128 | print('Epoch: {}, CE Loss: {}, AdvLoss: {}'.format( 129 | epoch, epoch_ce_loss, epoch_adv_loss)) 130 | self.save(os.path.join(save_path, "{}_{}.pt".format(name, epoch))) 131 | def test(self, test_loader): 132 | """Evaluate the model on the given test dataset. 133 | Parameters 134 | ---------- 135 | test_loader: Dataloader object for the test dataset. 136 | metric: Metric used for evaluation (AUC / F1). 137 | """ 138 | self.model.eval() 139 | scores = [] 140 | with torch.no_grad(): 141 | for data in test_loader: 142 | data = data.to(self.device) 143 | 144 | data = torch.transpose(data, dim0=1, dim1=2) 145 | data = data.reshape(data.shape[0], data.shape[1], data.shape[2]*data.shape[3]) 146 | 147 | 148 | logits = self.model(data).cpu().numpy() 149 | scores.append(logits) 150 | scores = -np.concatenate(scores) 151 | 152 | return scores 153 | 154 | 155 | def one_class_adv_loss(self, x_train_data): 156 | """Computes the adversarial loss: 157 | 1) Sample points initially at random around the positive training 158 | data points 159 | 2) Gradient ascent to find the most optimal point in set N_i(r) 160 | classified as +ve (label=0). This is done by maximizing 161 | the CE loss wrt label 0 162 | 3) Project the points between spheres of radius R and gamma * R 163 | (set N_i(r)) 164 | 4) Pass the calculated adversarial points through the model, 165 | and calculate the CE loss wrt target class 0 166 | 167 | Parameters 168 | ---------- 169 | x_train_data: Batch of data to compute loss on. 
170 | """ 171 | batch_size = len(x_train_data) 172 | # Randomly sample points around the training data 173 | # We will perform SGD on these to find the adversarial points 174 | x_adv = torch.randn(x_train_data.shape).to(self.device).detach().requires_grad_() 175 | x_adv_sampled = x_adv + x_train_data 176 | 177 | for step in range(self.ascent_num_steps): 178 | with torch.enable_grad(): 179 | 180 | new_targets = torch.zeros(batch_size, 1).to(self.device) 181 | new_targets = torch.squeeze(new_targets) 182 | new_targets = new_targets.to(torch.float) 183 | 184 | logits = self.model(x_adv_sampled) 185 | logits = torch.squeeze(logits, dim = 1) 186 | new_loss = F.binary_cross_entropy_with_logits(logits, new_targets) 187 | 188 | grad = torch.autograd.grad(new_loss, [x_adv_sampled])[0] 189 | grad_norm = torch.norm(grad, p=2, dim = tuple(range(1, grad.dim()))) 190 | grad_norm = grad_norm.view(-1, *[1]*(grad.dim()-1)) 191 | grad_normalized = grad/grad_norm 192 | with torch.no_grad(): 193 | x_adv_sampled.add_(self.ascent_step_size * grad_normalized) 194 | 195 | if (step + 1) % 10==0: 196 | # Project the normal points to the set N_i(r) 197 | h = x_adv_sampled - x_train_data 198 | norm_h = torch.sqrt(torch.sum(h**2, 199 | dim=tuple(range(1, h.dim())))) 200 | alpha = torch.clamp(norm_h, self.radius, 201 | self.gamma * self.radius).to(self.device) 202 | # Make use of broadcast to project h 203 | proj = (alpha/norm_h).view(-1, *[1] * (h.dim()-1)) 204 | h = proj * h 205 | x_adv_sampled = x_train_data + h #These adv_points are now on the surface of hyper-sphere 206 | 207 | adv_pred = self.model(x_adv_sampled) 208 | adv_pred = torch.squeeze(adv_pred, dim=1) 209 | adv_loss = F.binary_cross_entropy_with_logits(adv_pred, (new_targets * 0)) 210 | 211 | return adv_loss 212 | 213 | def save(self, path): 214 | torch.save(self.model.state_dict(),path) 215 | 216 | def load(self, path): 217 | self.model.load_state_dict(torch.load(path)) -------------------------------------------------------------------------------- /models/DeepSAD.py: -------------------------------------------------------------------------------- 1 | 2 | #%% 3 | import json 4 | import torch 5 | import logging 6 | import time 7 | import torch 8 | 9 | import torch.optim as optim 10 | 11 | class AETrainer: 12 | 13 | def __init__(self, device: str = 'cuda'): 14 | 15 | self.device = device 16 | 17 | def train(self, train_loader, ae_net, args): 18 | logger = logging.getLogger() 19 | 20 | # Set device for network 21 | ae_net = ae_net.to(self.device) 22 | 23 | # Set optimizer (Adam optimizer for now) 24 | optimizer = optim.Adam(ae_net.parameters(), lr=args.lr, weight_decay=args.weight_decay) 25 | 26 | # Training 27 | print('Starting pretraining...') 28 | start_time = time.time() 29 | ae_net.train() 30 | for epoch in range(10): 31 | 32 | 33 | loss_epoch = 0.0 34 | n_batches = 0 35 | epoch_start_time = time.time() 36 | for data in train_loader: 37 | 38 | if isinstance(data, list): 39 | data = data[0] 40 | 41 | x = data.to(self.device) 42 | x = torch.transpose(x, dim0=2, dim1=3) 43 | inputs = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 44 | 45 | # Zero the network parameter gradients 46 | optimizer.zero_grad() 47 | 48 | # Update network parameters via backpropagation: forward + backward + optimize 49 | outputs = ae_net(inputs) 50 | scores = torch.sum((outputs - inputs) ** 2, dim=tuple(range(1, outputs.dim()))) 51 | loss = torch.mean(scores) 52 | loss.backward() 53 | optimizer.step() 54 | 55 | loss_epoch += loss.item() 56 | n_batches += 1 57 | 58 | # 
log epoch statistics 59 | epoch_train_time = time.time() - epoch_start_time 60 | print(' Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}' 61 | .format(epoch + 1, 10, epoch_train_time, loss_epoch / n_batches)) 62 | 63 | pretrain_time = time.time() - start_time 64 | print('Pretraining time: %.3f' % pretrain_time) 65 | print('Finished pretraining.') 66 | 67 | return ae_net 68 | 69 | class DeepSVDDTrainer: 70 | 71 | def __init__(self, device: str = 'cuda'): 72 | 73 | self.device = device 74 | # Deep SVDD parameters 75 | 76 | self.c = None 77 | 78 | 79 | def train(self, train_loader, net, args): 80 | self.args = args 81 | 82 | # Set device for network 83 | net = net.to(self.device) 84 | 85 | # Set optimizer (Adam optimizer for now) 86 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) 87 | 88 | # Set learning rate scheduler 89 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20], gamma=0.1) 90 | 91 | # Initialize hypersphere center c (if c not loaded) 92 | if self.c is None: 93 | print('Initializing center c...') 94 | self.c = self.init_center_c(train_loader, net) 95 | print(self.c.shape) 96 | print('Center c initialized.') 97 | 98 | # Training 99 | print('Starting training...') 100 | start_time = time.time() 101 | net.train() 102 | 103 | save_path = os.path.join(args.output_dir,args.name) 104 | if not os.path.exists(save_path): 105 | os.makedirs(save_path) 106 | 107 | for epoch in range(args.n_epochs): 108 | 109 | scheduler.step() 110 | 111 | loss_epoch = 0.0 112 | n_batches = 0 113 | epoch_start_time = time.time() 114 | 115 | for data in train_loader: 116 | x = data.to(self.device) 117 | x = torch.transpose(x, dim0=2, dim1=3) 118 | inputs = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 119 | 120 | # Zero the network parameter gradients 121 | optimizer.zero_grad() 122 | 123 | # Update network parameters via backpropagation: forward + backward + optimize 124 | outputs = net(inputs).squeeze().mean(dim=-1) 125 | dist = torch.sum((outputs - self.c) ** 2, dim=1) 126 | 127 | loss = torch.mean(dist) 128 | loss.backward() 129 | optimizer.step() 130 | 131 | loss_epoch += loss.item() 132 | n_batches += 1 133 | 134 | # log epoch statistics 135 | epoch_train_time = time.time() - epoch_start_time 136 | print(' Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}' 137 | .format(epoch + 1, args.n_epochs, epoch_train_time, loss_epoch / n_batches)) 138 | torch.save({'c': self.c, 'net_dict': net.state_dict()}, os.path.join(save_path, "{}_{}.pt".format(args.name, epoch))) 139 | self.train_time = time.time() - start_time 140 | print('Training time: %.3f' % self.train_time) 141 | 142 | print('Finished training.') 143 | 144 | return net 145 | 146 | def init_center_c(self, train_loader, net, eps=0.1): 147 | """Initialize hypersphere center c as the mean from an initial forward pass on the data.""" 148 | n_samples = 0 149 | c = 0.0 150 | 151 | net.eval() 152 | with torch.no_grad(): 153 | for data in train_loader: 154 | # get the inputs of the batch 155 | x = data.to(self.device) 156 | x = torch.transpose(x, dim0=2, dim1=3) 157 | inputs = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 158 | outputs = net(inputs).squeeze() 159 | n_samples += outputs.shape[0] 160 | c += torch.sum(outputs, dim=0) 161 | 162 | c /= n_samples 163 | 164 | # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights. 
165 | c[(abs(c) < eps) & (c < 0)] = -eps 166 | c[(abs(c) < eps) & (c > 0)] = eps 167 | 168 | return c.mean(dim=-1) 169 | 170 | 171 | from models.RNN import RecurrentAE 172 | from models.GAN import CNNAE 173 | 174 | 175 | class DeepSVDD(object): 176 | 177 | def __init__(self, n_features, hidden_size, device): 178 | 179 | self.c = None # hypersphere center c 180 | 181 | 182 | self.trainer = None 183 | 184 | # if encoder=='RNN': 185 | # self.ae_net = RecurrentAE(n_features, hidden_size, device) 186 | self.ae_net = CNNAE(n_features, hidden_size).to(device) 187 | self.net = self.ae_net.encoder 188 | 189 | self.ae_trainer = None 190 | self.results = { 191 | 'test_auc': None 192 | } 193 | 194 | def train(self, dataset, args, device: str = 'cuda'): 195 | """Trains the Deep SVDD model on the training data.""" 196 | 197 | self.trainer = DeepSVDDTrainer(device=device) 198 | # Get the model 199 | self.trainer.train(dataset, self.net, args) 200 | self.c = self.trainer.c 201 | 202 | def test(self, test_loader, delta_t, sigma, device): 203 | from utils import roc_auc_all 204 | import numpy as np 205 | self.net.eval() 206 | self.net.to(device) 207 | loss = [] 208 | 209 | with torch.no_grad(): 210 | for data in test_loader: 211 | 212 | x = data.to(device) 213 | x = torch.transpose(x, dim0=2, dim1=3) 214 | inputs = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 215 | outputs = self.net(inputs).squeeze().mean(dim=-1) 216 | batch_loss= torch.sum((outputs - self.c) ** 2, dim=1).cpu().numpy() 217 | loss.append(batch_loss) 218 | loss = np.concatenate(loss) 219 | 220 | auc_score, fps,tps = roc_auc_all(loss, delta_t, sigma) 221 | print("meann: {:.4f}, median: {:.4f}, auc:{:.4f}".format(np.mean(loss), np.median(loss),auc_score))# %% 222 | self.results['test_auc'] = auc_score 223 | return auc_score, fps,tps 224 | 225 | def pretrain(self, train_loader, args, device): 226 | """Pretrains the weights for the Deep SVDD network \phi via autoencoder.""" 227 | 228 | self.ae_trainer = AETrainer(device=device) 229 | self.ae_net = self.ae_trainer.train(train_loader, self.ae_net, args) 230 | self.net = self.ae_net.encoder 231 | 232 | def save_model(self, export_model): 233 | """Save Deep SVDD model to export_model.""" 234 | 235 | net_dict = self.net.state_dict() 236 | 237 | torch.save({'c': self.c, 238 | 'net_dict': net_dict}, export_model) 239 | 240 | def load_model(self, model_path): 241 | """Load Deep SVDD model from model_path.""" 242 | 243 | model_dict = torch.load(model_path) 244 | 245 | self.c = model_dict['c'] 246 | self.net.load_state_dict(model_dict['net_dict']) 247 | 248 | def save_results(self, export_json): 249 | """Save results dict to a JSON-file.""" 250 | with open(export_json, 'w') as fp: 251 | json.dump(self.results, fp) 252 | 253 | # %% 254 | import os 255 | class DeepSADTrainer: 256 | 257 | def __init__(self, device: str = 'cuda'): 258 | 259 | self.device = device 260 | 261 | self.c = None 262 | 263 | 264 | def train(self, train_loader, net, args): 265 | self.args = args 266 | 267 | # Set device for network 268 | net = net.to(self.device) 269 | 270 | # Set optimizer (Adam optimizer for now) 271 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) 272 | 273 | # Set learning rate scheduler 274 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20], gamma=0.1) 275 | 276 | # Initialize hypersphere center c (if c not loaded) 277 | if self.c is None: 278 | print('Initializing center c...') 279 | self.c = self.init_center_c(train_loader, net) 280 | 
print('Center c initialized.') 281 | 282 | # Training 283 | print('Starting training...') 284 | start_time = time.time() 285 | net.train() 286 | 287 | save_path = os.path.join(args.output_dir,args.name) 288 | if not os.path.exists(save_path): 289 | os.makedirs(save_path) 290 | 291 | for epoch in range(args.n_epochs): 292 | 293 | 294 | 295 | loss_epoch = 0.0 296 | n_batches = 0 297 | epoch_start_time = time.time() 298 | 299 | for data, semi_targets in train_loader: 300 | 301 | x = data.to(self.device) 302 | x = torch.transpose(x, dim0=2, dim1=3) 303 | inputs = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 304 | 305 | semi_targets = semi_targets.to(self.device) 306 | # Zero the network parameter gradients 307 | optimizer.zero_grad() 308 | 309 | # Update network parameters via backpropagation: forward + backward + optimize 310 | outputs = net(inputs).squeeze().mean(dim=-1) 311 | dist = torch.sum((outputs - self.c) ** 2, dim=1) 312 | 313 | losses = torch.where(semi_targets == 0, dist, args.eta * ((dist + 1e-6) ** semi_targets.float())) 314 | loss = torch.mean(losses) 315 | loss.backward() 316 | optimizer.step() 317 | 318 | loss_epoch += loss.item() 319 | n_batches += 1 320 | scheduler.step() 321 | # log epoch statistics 322 | 323 | epoch_train_time = time.time() - epoch_start_time 324 | print(' Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f}' 325 | .format(epoch + 1, args.n_epochs, epoch_train_time, loss_epoch / n_batches)) 326 | torch.save({'c': self.c, 'net_dict': net.state_dict()}, os.path.join(save_path, "{}_{}.pt".format(args.name, epoch))) 327 | 328 | self.train_time = time.time() - start_time 329 | print('Training time: %.3f' % self.train_time) 330 | 331 | print('Finished training.') 332 | 333 | return net 334 | 335 | def init_center_c(self, train_loader, net, eps=0.1): 336 | """Initialize hypersphere center c as the mean from an initial forward pass on the data.""" 337 | n_samples = 0 338 | c = 0.0 339 | 340 | net.eval() 341 | with torch.no_grad(): 342 | for data, _ in train_loader: 343 | # get the inputs of the batch 344 | x = data.to(self.device) 345 | x = torch.transpose(x, dim0=2, dim1=3) 346 | inputs = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 347 | outputs = net(inputs).squeeze() 348 | n_samples += outputs.shape[0] 349 | c += torch.sum(outputs, dim=0) 350 | 351 | c /= n_samples 352 | 353 | # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights. 
354 | c[(abs(c) < eps) & (c < 0)] = -eps 355 | c[(abs(c) < eps) & (c > 0)] = eps 356 | 357 | return c.mean(dim=-1) 358 | class DeepSAD(object): 359 | 360 | def __init__(self, n_features, hidden_size, device): 361 | 362 | self.c = None # hypersphere center c 363 | 364 | 365 | self.trainer = None 366 | 367 | self.ae_net = CNNAE(n_features, hidden_size).to(device) 368 | self.net = self.ae_net.encoder 369 | 370 | self.ae_trainer = None 371 | self.results = { 372 | 'test_auc': None 373 | } 374 | 375 | def train(self, dataset, args, device: str = 'cuda'): 376 | 377 | self.trainer = DeepSADTrainer(device=device) 378 | # Get the model 379 | self.trainer.train(dataset, self.net, args) 380 | self.c = self.trainer.c 381 | 382 | def test(self, test_loader, delta_t, sigma, device): 383 | from utils import roc_auc_all 384 | import numpy as np 385 | self.net.eval() 386 | self.net.to(device) 387 | loss = [] 388 | 389 | with torch.no_grad(): 390 | for data in test_loader: 391 | 392 | x = data.to(device) 393 | x = torch.transpose(x, dim0=2, dim1=3) 394 | inputs = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 395 | outputs = self.net(inputs).squeeze().mean(dim=-1) 396 | batch_loss= torch.sum((outputs - self.c) ** 2, dim=1).cpu().numpy() 397 | loss.append(batch_loss) 398 | loss = np.concatenate(loss) 399 | 400 | auc_score, fps,tps = roc_auc_all(loss, delta_t, sigma) 401 | print("mean: {:.4f}, median: {:.4f}, auc:{:.4f}".format(np.mean(loss), np.median(loss),auc_score))# %% 402 | self.results['test_auc'] = auc_score 403 | return auc_score,fps,tps 404 | 405 | def pretrain(self, train_loader, args, device): 406 | 407 | self.ae_trainer = AETrainer(device=device) 408 | self.ae_net = self.ae_trainer.train(train_loader, self.ae_net, args) 409 | self.net = self.ae_net.encoder 410 | 411 | def save_model(self, export_model): 412 | """Save Deep SVDD model to export_model.""" 413 | 414 | net_dict = self.net.state_dict() 415 | 416 | torch.save({'c': self.c, 417 | 'net_dict': net_dict}, export_model) 418 | 419 | def load_model(self, model_path, load_ae=False): 420 | """Load Deep SVDD model from model_path.""" 421 | 422 | model_dict = torch.load(model_path) 423 | 424 | self.c = model_dict['c'] 425 | self.net.load_state_dict(model_dict['net_dict']) 426 | 427 | def save_results(self, export_json): 428 | """Save results dict to a JSON-file.""" 429 | with open(export_json, 'w') as fp: 430 | json.dump(self.results, fp) 431 | # %% 432 | -------------------------------------------------------------------------------- /models/GAN.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from timeit import default_timer as timer 7 | def ConvEncoder(activation = nn.LeakyReLU, in_channels:int = 3, n_c:int = 64, 8 | k_size:int = 5): 9 | 10 | enc = nn.Sequential(*(nn.Conv1d(in_channels, n_c, k_size, stride=2, padding=2), 11 | nn.BatchNorm1d(n_c), 12 | activation(), 13 | nn.Conv1d(n_c, n_c*2, k_size, stride=2, padding=2), 14 | nn.BatchNorm1d(n_c*2), 15 | activation(), 16 | nn.Conv1d(n_c*2, n_c*4, k_size, stride=2, padding=2), 17 | nn.BatchNorm1d(n_c*4), 18 | activation())) 19 | return enc 20 | 21 | def ConvDecoder(activation = nn.LeakyReLU, in_channels:int = 3, n_c:int = 64, 22 | k_size:int = 5): 23 | 24 | decoder = nn.Sequential(*(nn.ConvTranspose1d(n_c*4, n_c*2, k_size, stride=2, padding=2, output_padding=0), 25 | torch.nn.BatchNorm1d(n_c*2), 26 | activation(), 27 | 
torch.nn.ConvTranspose1d(n_c*2, n_c, k_size,stride=2, padding=2, output_padding=1), 28 | torch.nn.BatchNorm1d(n_c), 29 | activation(), 30 | torch.nn.ConvTranspose1d(n_c, in_channels, k_size,stride=2, padding=2, output_padding=1))) 31 | return decoder 32 | 33 | class CNNAE(torch.nn.Module): 34 | """Recurrent autoencoder""" 35 | 36 | def __init__(self,in_channels:int = 3, n_channels:int = 16, 37 | kernel_size:int = 5): 38 | super(CNNAE, self).__init__() 39 | 40 | # Encoder and decoder argsuration 41 | activation = torch.nn.LeakyReLU 42 | self.in_channels = in_channels 43 | self.n_c = n_channels 44 | self.k_size = kernel_size 45 | 46 | self.encoder = ConvEncoder(activation, in_channels, n_channels, kernel_size) 47 | 48 | self.decoder = ConvDecoder(activation, in_channels, n_channels, kernel_size) 49 | 50 | 51 | def forward(self, x:torch.Tensor): 52 | 53 | z = self.encoder.forward(x) 54 | 55 | x_out = self.decoder.forward(z) 56 | 57 | return x_out 58 | 59 | class R_Net(torch.nn.Module): 60 | 61 | def __init__(self, activation = torch.nn.LeakyReLU, in_channels:int = 3, n_channels:int = 16, 62 | kernel_size:int = 5, std:float = 0.2): 63 | 64 | super(R_Net, self).__init__() 65 | 66 | self.activation = activation 67 | self.in_channels = in_channels 68 | self.n_c = n_channels 69 | self.k_size = kernel_size 70 | self.std = std 71 | 72 | self.Encoder = ConvEncoder(activation, in_channels, n_channels, kernel_size) 73 | 74 | self.Decoder = ConvDecoder(activation, in_channels, n_channels, kernel_size) 75 | 76 | def forward(self, x:torch.Tensor, noise:bool = True): 77 | 78 | x_hat = self.add_noise(x) if noise else x 79 | z = self.Encoder.forward(x_hat) 80 | 81 | x_out = self.Decoder.forward(z) 82 | 83 | return x_out 84 | 85 | def add_noise(self, x): 86 | 87 | noise = torch.randn_like(x) * self.std 88 | x_hat = x + noise 89 | 90 | return x_hat 91 | 92 | class D_Net(torch.nn.Module): 93 | 94 | def __init__(self, in_resolution:int, activation = torch.nn.LeakyReLU, in_channels:int = 3, n_channels:int = 16, kernel_size:int = 5): 95 | 96 | super(D_Net, self).__init__() 97 | 98 | self.activation = activation 99 | self.in_resolution = in_resolution 100 | self.in_channels = in_channels 101 | self.n_c = n_channels 102 | self.k_size = kernel_size 103 | 104 | self.cnn = ConvEncoder(activation, in_channels, n_channels, kernel_size) 105 | 106 | # Compute output dimension after conv part of D network 107 | 108 | self.out_dim = self._compute_out_dim() 109 | 110 | self.fc = torch.nn.Linear(self.out_dim, 1) 111 | 112 | def _compute_out_dim(self): 113 | 114 | test_x = torch.Tensor(1, self.in_channels, self.in_resolution) 115 | for p in self.cnn.parameters(): 116 | p.requires_grad = False 117 | test_x = self.cnn(test_x) 118 | out_dim = torch.prod(torch.tensor(test_x.shape[1:])).item() 119 | for p in self.cnn.parameters(): 120 | p.requires_grad = True 121 | 122 | return out_dim 123 | 124 | def forward(self, x:torch.Tensor): 125 | 126 | x = self.cnn(x) 127 | 128 | x = torch.flatten(x, start_dim = 1) 129 | 130 | out = self.fc(x) 131 | 132 | return out 133 | 134 | def R_Loss(d_net: torch.nn.Module, x_real: torch.Tensor, x_fake: torch.Tensor, lambd: float) -> dict: 135 | 136 | pred = d_net(x_fake) 137 | y = torch.ones_like(pred) 138 | 139 | rec_loss = F.mse_loss(x_fake, x_real) 140 | gen_loss = F.binary_cross_entropy_with_logits(pred, y) # generator loss 141 | 142 | L_r = gen_loss + lambd * rec_loss 143 | 144 | return {'rec_loss' : rec_loss, 'gen_loss' : gen_loss, 'L_r' : L_r} 145 | 146 | def D_Loss(d_net: torch.nn.Module, 
x_real: torch.Tensor, x_fake: torch.Tensor) -> torch.Tensor: 147 | 148 | pred_real = d_net(x_real) 149 | pred_fake = d_net(x_fake.detach()) 150 | 151 | y_real = torch.ones_like(pred_real) 152 | y_fake = torch.zeros_like(pred_fake) 153 | 154 | real_loss = F.binary_cross_entropy_with_logits(pred_real, y_real) 155 | fake_loss = F.binary_cross_entropy_with_logits(pred_fake, y_fake) 156 | 157 | return real_loss + fake_loss 158 | 159 | # Wasserstein GAN loss (https://arxiv.org/abs/1701.07875) 160 | 161 | def R_WLoss(d_net: torch.nn.Module, x_real: torch.Tensor, x_fake: torch.Tensor, lambd: float) -> dict: 162 | 163 | pred = torch.sigmoid(d_net(x_fake)) 164 | 165 | rec_loss = F.mse_loss(x_fake, x_real) 166 | gen_loss = -torch.mean(pred) # Wasserstein G loss: - E[ D(G(x)) ] 167 | 168 | L_r = gen_loss + lambd * rec_loss 169 | 170 | return {'rec_loss' : rec_loss, 'gen_loss' : gen_loss, 'L_r' : L_r} 171 | 172 | def D_WLoss(d_net: torch.nn.Module, x_real: torch.Tensor, x_fake: torch.Tensor) -> torch.Tensor: 173 | 174 | pred_real = torch.sigmoid(d_net(x_real)) 175 | pred_fake = torch.sigmoid(d_net(x_fake.detach())) 176 | 177 | dis_loss = -torch.mean(pred_real) + torch.mean(pred_fake) # Wasserstein D loss: -E[D(x_real)] + E[D(x_fake)] 178 | 179 | return dis_loss 180 | 181 | # %% 182 | def train_model(r_net: torch.nn.Module, 183 | d_net: torch.nn.Module, 184 | train_dataset: torch.utils.data.Dataset, 185 | valid_dataset: torch.utils.data.Dataset, 186 | r_loss = R_Loss, 187 | d_loss = D_Loss, 188 | lr_scheduler = None, 189 | optimizer_class = torch.optim.Adam, 190 | optim_r_params: dict = {}, 191 | optim_d_params: dict = {}, 192 | learning_rate: float = 0.001, 193 | scheduler_r_params: dict = {}, 194 | scheduler_d_params: dict = {}, 195 | batch_size: int = 1024, 196 | max_epochs: int = 40, 197 | epoch_step: int = 1, 198 | save_step: int = 5, 199 | lambd: float = 0.2, 200 | device: torch.device = torch.device('cuda'), 201 | save_path: str = ".") -> tuple: 202 | 203 | optim_r = optimizer_class(r_net.parameters(), lr = learning_rate, **optim_r_params) 204 | optim_d = optimizer_class(d_net.parameters(), lr = learning_rate, **optim_d_params) 205 | 206 | if lr_scheduler: 207 | scheduler_r = lr_scheduler(optim_r, **scheduler_r_params) 208 | scheduler_d = lr_scheduler(optim_d, **scheduler_d_params) 209 | 210 | train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size) 211 | valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size) 212 | 213 | for epoch in range(max_epochs): 214 | 215 | start = timer() 216 | train_metrics = train_single_epoch(r_net, d_net, optim_r, optim_d, r_loss, d_loss, train_loader, lambd, device) 217 | valid_metrics = validate_single_epoch(r_net, d_net, r_loss, d_loss, valid_loader, device) 218 | time = timer() - start 219 | 220 | 221 | if epoch % epoch_step == 0: 222 | print(f'Epoch {epoch}:') 223 | print('Train Metrics:', train_metrics) 224 | print('Val Metrics:', valid_metrics) 225 | print(f'TIME: {time:.2f} s') 226 | 227 | if lr_scheduler: 228 | scheduler_r.step() 229 | scheduler_d.step() 230 | 231 | if epoch % save_step == 0: 232 | torch.save(r_net.state_dict(), os.path.join(save_path, "r_net_{}.pt".format(epoch))) 233 | torch.save(d_net.state_dict(), os.path.join(save_path, "d_net_{}.pt".format(epoch))) 234 | print(f'Saving model on epoch {epoch}') 235 | 236 | return (r_net, d_net) 237 | 238 | def train_single_epoch(r_net, d_net, optim_r, optim_d, r_loss, d_loss, train_loader, lambd, device) -> dict: 239 | 240 | r_net.train() 
241 | d_net.train() 242 | 243 | train_metrics = {'rec_loss' : 0, 'gen_loss' : 0, 'dis_loss' : 0} 244 | 245 | for data in train_loader: 246 | 247 | x = data.to(device) 248 | x = torch.transpose(x, dim0=2, dim1=3) 249 | x_real = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 250 | 251 | x_fake = r_net(x_real) 252 | d_net.zero_grad() 253 | 254 | dis_loss = d_loss(d_net, x_real, x_fake) 255 | 256 | dis_loss.backward() 257 | optim_d.step() 258 | 259 | r_net.zero_grad() 260 | 261 | r_metrics = r_loss(d_net, x_real, x_fake, lambd) # L_r = gen_loss + lambda * rec_loss 262 | 263 | r_metrics['L_r'].backward() 264 | optim_r.step() 265 | 266 | train_metrics['rec_loss'] += r_metrics['rec_loss'] 267 | train_metrics['gen_loss'] += r_metrics['gen_loss'] 268 | train_metrics['dis_loss'] += dis_loss 269 | 270 | train_metrics['rec_loss'] = train_metrics['rec_loss'].item() / (len(train_loader.dataset) / train_loader.batch_size) 271 | train_metrics['gen_loss'] = train_metrics['gen_loss'].item() / (len(train_loader.dataset) / train_loader.batch_size) 272 | train_metrics['dis_loss'] = train_metrics['dis_loss'].item() / (len(train_loader.dataset) / train_loader.batch_size) 273 | 274 | return train_metrics 275 | 276 | def validate_single_epoch(r_net, d_net, r_loss, d_loss, valid_loader, device) -> dict: 277 | 278 | r_net.eval() 279 | d_net.eval() 280 | 281 | valid_metrics = {'rec_loss' : 0, 'gen_loss' : 0, 'dis_loss' : 0} 282 | 283 | with torch.no_grad(): 284 | for data in valid_loader: 285 | 286 | x = data.to(device) 287 | x = torch.transpose(x, dim0=2, dim1=3) 288 | x_real = x.reshape(x.shape[0], x.shape[1]*x.shape[2], x.shape[3]) 289 | 290 | x_fake = r_net(x_real) 291 | 292 | dis_loss = d_loss(d_net, x_real, x_fake) 293 | 294 | r_metrics = r_loss(d_net, x_real, x_fake, 0) 295 | 296 | valid_metrics['rec_loss'] += r_metrics['rec_loss'] 297 | valid_metrics['gen_loss'] += r_metrics['gen_loss'] 298 | valid_metrics['dis_loss'] += dis_loss 299 | 300 | valid_metrics['rec_loss'] = valid_metrics['rec_loss'].item() / (len(valid_loader.dataset) / valid_loader.batch_size) 301 | valid_metrics['gen_loss'] = valid_metrics['gen_loss'].item() / (len(valid_loader.dataset) / valid_loader.batch_size) 302 | valid_metrics['dis_loss'] = valid_metrics['dis_loss'].item() / (len(valid_loader.dataset) / valid_loader.batch_size) 303 | 304 | return valid_metrics 305 | 306 | 307 | 308 | # %% 309 | -------------------------------------------------------------------------------- /models/GANF.py: -------------------------------------------------------------------------------- 1 | 2 | #%% 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.NF import MAF, RealNVP 6 | import torch 7 | 8 | class GNN(nn.Module): 9 | """ 10 | The GNN module applied in GANF 11 | """ 12 | def __init__(self, input_size, hidden_size): 13 | 14 | super(GNN, self).__init__() 15 | self.lin_n = nn.Linear(input_size, hidden_size) 16 | self.lin_r = nn.Linear(input_size, hidden_size, bias=False) 17 | self.lin_2 = nn.Linear(hidden_size, hidden_size) 18 | 19 | def forward(self, h, A): 20 | ## A: K X K 21 | ## H: N X K X L X D 22 | 23 | h_n = self.lin_n(torch.einsum('nkld,kj->njld',h,A)) 24 | h_r = self.lin_r(h[:,:,:-1]) 25 | h_n[:,:,1:] += h_r 26 | h = self.lin_2(F.relu(h_n)) 27 | 28 | return h 29 | 30 | 31 | class GANF(nn.Module): 32 | 33 | def __init__ (self, n_blocks, input_size, hidden_size, n_hidden ,dropout = 0.1, model="MAF", batch_norm=True): 34 | super(GANF, self).__init__() 35 | 36 | self.rnn = 
nn.LSTM(input_size=input_size,hidden_size=hidden_size,batch_first=True, dropout=dropout) 37 | self.gcn = GNN(input_size=hidden_size, hidden_size=hidden_size) 38 | if model=="MAF": 39 | self.nf = MAF(n_blocks, input_size, hidden_size, n_hidden, cond_label_size=hidden_size, batch_norm=batch_norm,activation='tanh') 40 | else: 41 | self.nf = RealNVP(n_blocks, input_size, hidden_size, n_hidden, cond_label_size=hidden_size, batch_norm=batch_norm) 42 | 43 | def forward(self, x, A): 44 | 45 | return self.test(x, A).mean() 46 | 47 | def test(self, x, A): 48 | # x: N X K X L X D 49 | full_shape = x.shape 50 | 51 | # reshape: N*K, L, D 52 | x = x.reshape((x.shape[0]*x.shape[1], x.shape[2], x.shape[3])) 53 | h,_ = self.rnn(x) 54 | 55 | # resahpe: N, K, L, H 56 | h = h.reshape((full_shape[0], full_shape[1], h.shape[1], h.shape[2])) 57 | 58 | 59 | h = self.gcn(h, A) 60 | 61 | # reshappe N*K*L,H 62 | h = h.reshape((-1,h.shape[3])) 63 | x = x.reshape((-1,full_shape[3])) 64 | 65 | log_prob = self.nf.log_prob(x,h).reshape([full_shape[0],-1])#*full_shape[1]*full_shape[2] 66 | log_prob = log_prob.mean(dim=1) 67 | 68 | return log_prob 69 | 70 | def locate(self, x, A): 71 | # x: N X K X L X D 72 | full_shape = x.shape 73 | 74 | # reshape: N*K, L, D 75 | x = x.reshape((x.shape[0]*x.shape[1], x.shape[2], x.shape[3])) 76 | h,_ = self.rnn(x) 77 | 78 | # resahpe: N, K, L, H 79 | h = h.reshape((full_shape[0], full_shape[1], h.shape[1], h.shape[2])) 80 | 81 | 82 | h = self.gcn(h, A) 83 | 84 | # reshappe N*K*L,H 85 | h = h.reshape((-1,h.shape[3])) 86 | x = x.reshape((-1,full_shape[3])) 87 | 88 | log_prob = self.nf.log_prob(x,h).reshape([full_shape[0],full_shape[1],-1])#*full_shape[1]*full_shape[2] 89 | log_prob = log_prob.mean(dim=2) 90 | 91 | return log_prob 92 | -------------------------------------------------------------------------------- /models/NF.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributions as D 7 | import math 8 | import copy 9 | 10 | 11 | # -------------------- 12 | # Model layers and helpers 13 | # -------------------- 14 | 15 | def create_masks(input_size, hidden_size, n_hidden, input_order='sequential', input_degrees=None): 16 | # MADE paper sec 4: 17 | # degrees of connections between layers -- ensure at most in_degree - 1 connections 18 | degrees = [] 19 | 20 | # set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades); 21 | # else init input degrees based on strategy in input_order (sequential or random) 22 | if input_size>1: 23 | if input_order == 'sequential': 24 | degrees += [torch.arange(input_size)] if input_degrees is None else [input_degrees] 25 | for _ in range(n_hidden + 1): 26 | degrees += [torch.arange(hidden_size) % (input_size - 1)] 27 | degrees += [torch.arange(input_size) % input_size - 1] if input_degrees is None else [input_degrees % input_size - 1] 28 | 29 | elif input_order == 'random': 30 | degrees += [torch.randperm(input_size)] if input_degrees is None else [input_degrees] 31 | for _ in range(n_hidden + 1): 32 | min_prev_degree = min(degrees[-1].min().item(), input_size - 1) 33 | degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))] 34 | min_prev_degree = min(degrees[-1].min().item(), input_size - 1) 35 | degrees += [torch.randint(min_prev_degree, input_size, (input_size,)) - 1] if input_degrees is None else [input_degrees - 1] 36 | else: 37 | degrees += 
[torch.zeros([1]).long()] 38 | for _ in range(n_hidden+1): 39 | degrees += [torch.zeros([hidden_size]).long()] 40 | degrees += [torch.zeros([input_size]).long()] 41 | # construct masks 42 | masks = [] 43 | for (d0, d1) in zip(degrees[:-1], degrees[1:]): 44 | masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()] 45 | 46 | return masks, degrees[0] 47 | 48 | #%% 49 | 50 | def create_masks_pmu(input_size, hidden_size, n_hidden, input_order='sequential', input_degrees=None): 51 | # MADE paper sec 4: 52 | # degrees of connections between layers -- ensure at most in_degree - 1 connections 53 | degrees = [] 54 | 55 | # set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades); 56 | # else init input degrees based on strategy in input_order (sequential or random) 57 | if input_order == 'sequential': 58 | degrees += [torch.arange(input_size)] if input_degrees is None else [input_degrees] 59 | for _ in range(n_hidden + 1): 60 | degrees += [torch.arange(hidden_size) % (input_size - 1)] 61 | degrees += [torch.arange(input_size) % input_size - 1] if input_degrees is None else [input_degrees % input_size - 1] 62 | 63 | # construct masks 64 | masks = [] 65 | for (d0, d1) in zip(degrees[:-1], degrees[1:]): 66 | masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()] 67 | masks[0] = masks[0].repeat_interleave(3, dim=1) 68 | masks[-1] = masks[-1].repeat_interleave(3, dim=0) 69 | 70 | return masks, degrees[0] 71 | #%% 72 | class MaskedLinear(nn.Linear): 73 | """ MADE building block layer """ 74 | def __init__(self, input_size, n_outputs, mask, cond_label_size=None): 75 | super().__init__(input_size, n_outputs) 76 | 77 | self.register_buffer('mask', mask) 78 | 79 | self.cond_label_size = cond_label_size 80 | if cond_label_size is not None: 81 | self.cond_weight = nn.Parameter(torch.rand(n_outputs, cond_label_size) / math.sqrt(cond_label_size)) 82 | 83 | def forward(self, x, y=None): 84 | out = F.linear(x, self.weight * self.mask, self.bias) 85 | if y is not None: 86 | out = out + F.linear(y, self.cond_weight) 87 | return out 88 | 89 | def extra_repr(self): 90 | return 'in_features={}, out_features={}, bias={}'.format( 91 | self.in_features, self.out_features, self.bias is not None 92 | ) + (self.cond_label_size != None) * ', cond_features={}'.format(self.cond_label_size) 93 | 94 | 95 | class LinearMaskedCoupling(nn.Module): 96 | """ Modified RealNVP Coupling Layers per the MAF paper """ 97 | def __init__(self, input_size, hidden_size, n_hidden, mask, cond_label_size=None): 98 | super().__init__() 99 | 100 | self.register_buffer('mask', mask) 101 | 102 | # scale function 103 | s_net = [nn.Linear(input_size + (cond_label_size if cond_label_size is not None else 0), hidden_size)] 104 | for _ in range(n_hidden): 105 | s_net += [nn.Tanh(), nn.Linear(hidden_size, hidden_size)] 106 | s_net += [nn.Tanh(), nn.Linear(hidden_size, input_size)] 107 | self.s_net = nn.Sequential(*s_net) 108 | 109 | # translation function 110 | self.t_net = copy.deepcopy(self.s_net) 111 | # replace Tanh with ReLU's per MAF paper 112 | for i in range(len(self.t_net)): 113 | if not isinstance(self.t_net[i], nn.Linear): self.t_net[i] = nn.ReLU() 114 | 115 | def forward(self, x, y=None): 116 | # apply mask 117 | mx = x * self.mask 118 | 119 | # run through model 120 | s = self.s_net(mx if y is None else torch.cat([y, mx], dim=1)) 121 | t = self.t_net(mx if y is None else torch.cat([y, mx], dim=1)) 122 | u = mx + (1 - self.mask) * (x - t) * torch.exp(-s) # cf RealNVP eq 8 where u corresponds 
to x (here we're modeling u) 123 | 124 | log_abs_det_jacobian = - (1 - self.mask) * s # log det du/dx; cf RealNVP 8 and 6; note, sum over input_size done at model log_prob 125 | 126 | return u, log_abs_det_jacobian 127 | 128 | def inverse(self, u, y=None): 129 | # apply mask 130 | mu = u * self.mask 131 | 132 | # run through model 133 | s = self.s_net(mu if y is None else torch.cat([y, mu], dim=1)) 134 | t = self.t_net(mu if y is None else torch.cat([y, mu], dim=1)) 135 | x = mu + (1 - self.mask) * (u * s.exp() + t) # cf RealNVP eq 7 136 | 137 | log_abs_det_jacobian = (1 - self.mask) * s # log det dx/du 138 | 139 | return x, log_abs_det_jacobian 140 | 141 | 142 | class BatchNorm(nn.Module): 143 | """ RealNVP BatchNorm layer """ 144 | def __init__(self, input_size, momentum=0.9, eps=1e-5): 145 | super().__init__() 146 | self.momentum = momentum 147 | self.eps = eps 148 | 149 | self.log_gamma = nn.Parameter(torch.zeros(input_size)) 150 | self.beta = nn.Parameter(torch.zeros(input_size)) 151 | 152 | self.register_buffer('running_mean', torch.zeros(input_size)) 153 | self.register_buffer('running_var', torch.ones(input_size)) 154 | 155 | def forward(self, x, cond_y=None): 156 | if self.training: 157 | self.batch_mean = x.mean(0) 158 | self.batch_var = x.var(0) # note MAF paper uses biased variance estimate; ie x.var(0, unbiased=False) 159 | 160 | # update running mean 161 | self.running_mean.mul_(self.momentum).add_(self.batch_mean.data * (1 - self.momentum)) 162 | self.running_var.mul_(self.momentum).add_(self.batch_var.data * (1 - self.momentum)) 163 | 164 | mean = self.batch_mean 165 | var = self.batch_var 166 | else: 167 | mean = self.running_mean 168 | var = self.running_var 169 | 170 | # compute normalized input (cf original batch norm paper algo 1) 171 | x_hat = (x - mean) / torch.sqrt(var + self.eps) 172 | y = self.log_gamma.exp() * x_hat + self.beta 173 | 174 | # compute log_abs_det_jacobian (cf RealNVP paper) 175 | log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps) 176 | # print('in sum log var {:6.3f} ; out sum log var {:6.3f}; sum log det {:8.3f}; mean log_gamma {:5.3f}; mean beta {:5.3f}'.format( 177 | # (var + self.eps).log().sum().data.numpy(), y.var(0).log().sum().data.numpy(), log_abs_det_jacobian.mean(0).item(), self.log_gamma.mean(), self.beta.mean())) 178 | return y, log_abs_det_jacobian.expand_as(x) 179 | 180 | def inverse(self, y, cond_y=None): 181 | if self.training: 182 | mean = self.batch_mean 183 | var = self.batch_var 184 | else: 185 | mean = self.running_mean 186 | var = self.running_var 187 | 188 | x_hat = (y - self.beta) * torch.exp(-self.log_gamma) 189 | x = x_hat * torch.sqrt(var + self.eps) + mean 190 | 191 | log_abs_det_jacobian = 0.5 * torch.log(var + self.eps) - self.log_gamma 192 | 193 | return x, log_abs_det_jacobian.expand_as(x) 194 | 195 | 196 | class FlowSequential(nn.Sequential): 197 | """ Container for layers of a normalizing flow """ 198 | def forward(self, x, y): 199 | sum_log_abs_det_jacobians = 0 200 | for module in self: 201 | x, log_abs_det_jacobian = module(x, y) 202 | sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian 203 | return x, sum_log_abs_det_jacobians 204 | 205 | def inverse(self, u, y): 206 | sum_log_abs_det_jacobians = 0 207 | for module in reversed(self): 208 | u, log_abs_det_jacobian = module.inverse(u, y) 209 | sum_log_abs_det_jacobians = sum_log_abs_det_jacobians + log_abs_det_jacobian 210 | return u, sum_log_abs_det_jacobians 211 | 212 | # -------------------- 213 | # Models 214 | # 
-------------------- 215 | 216 | class MADE(nn.Module): 217 | def __init__(self, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', input_degrees=None): 218 | """ 219 | Args: 220 | input_size -- scalar; dim of inputs 221 | hidden_size -- scalar; dim of hidden layers 222 | n_hidden -- scalar; number of hidden layers 223 | activation -- str; activation function to use 224 | input_order -- str or tensor; variable order for creating the autoregressive masks (sequential|random) 225 | or the order flipped from the previous layer in a stack of mades 226 | conditional -- bool; whether model is conditional 227 | """ 228 | super().__init__() 229 | # base distribution for calculation of log prob under the model 230 | self.register_buffer('base_dist_mean', torch.zeros(input_size)) 231 | self.register_buffer('base_dist_var', torch.ones(input_size)) 232 | 233 | # create masks 234 | masks, self.input_degrees = create_masks(input_size, hidden_size, n_hidden, input_order, input_degrees) 235 | 236 | # setup activation 237 | if activation == 'relu': 238 | activation_fn = nn.ReLU() 239 | elif activation == 'tanh': 240 | activation_fn = nn.Tanh() 241 | else: 242 | raise ValueError('Check activation function.') 243 | 244 | # construct model 245 | self.net_input = MaskedLinear(input_size, hidden_size, masks[0], cond_label_size) 246 | self.net = [] 247 | for m in masks[1:-1]: 248 | self.net += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)] 249 | self.net += [activation_fn, MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2,1))] 250 | self.net = nn.Sequential(*self.net) 251 | 252 | @property 253 | def base_dist(self): 254 | return D.Normal(self.base_dist_mean, self.base_dist_var) 255 | 256 | def forward(self, x, y=None): 257 | # MAF eq 4 -- return mean and log std 258 | m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=1) 259 | u = (x - m) * torch.exp(-loga) 260 | # MAF eq 5 261 | log_abs_det_jacobian = - loga 262 | return u, log_abs_det_jacobian 263 | 264 | def inverse(self, u, y=None, sum_log_abs_det_jacobians=None): 265 | # MAF eq 3 266 | D = u.shape[1] 267 | x = torch.zeros_like(u) 268 | # run through reverse model 269 | for i in self.input_degrees: 270 | m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=1) 271 | x[:,i] = u[:,i] * torch.exp(loga[:,i]) + m[:,i] 272 | log_abs_det_jacobian = loga 273 | return x, log_abs_det_jacobian 274 | 275 | def log_prob(self, x, y=None): 276 | u, log_abs_det_jacobian = self.forward(x, y) 277 | return torch.sum(self.base_dist.log_prob(u) + log_abs_det_jacobian, dim=1) 278 | 279 | 280 | class MADE_Full(nn.Module): 281 | def __init__(self, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', input_degrees=None): 282 | """ 283 | Args: 284 | input_size -- scalar; dim of inputs 285 | hidden_size -- scalar; dim of hidden layers 286 | n_hidden -- scalar; number of hidden layers 287 | activation -- str; activation function to use 288 | input_order -- str or tensor; variable order for creating the autoregressive masks (sequential|random) 289 | or the order flipped from the previous layer in a stack of mades 290 | conditional -- bool; whether model is conditional 291 | """ 292 | super().__init__() 293 | # base distribution for calculation of log prob under the model 294 | self.register_buffer('base_dist_mean', torch.zeros(input_size)) 295 | self.register_buffer('base_dist_var', torch.ones(input_size)) 296 | 297 | # create masks 298 | masks, 
self.input_degrees = create_masks_pmu(int(input_size/3), hidden_size, n_hidden, input_order, input_degrees) 299 | 300 | # setup activation 301 | if activation == 'relu': 302 | activation_fn = nn.ReLU() 303 | elif activation == 'tanh': 304 | activation_fn = nn.Tanh() 305 | else: 306 | raise ValueError('Check activation function.') 307 | 308 | # construct model 309 | self.net_input = MaskedLinear(input_size, hidden_size, masks[0], cond_label_size) 310 | self.net = [] 311 | for m in masks[1:-1]: 312 | self.net += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)] 313 | self.net += [activation_fn, MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2,1))] 314 | self.net = nn.Sequential(*self.net) 315 | 316 | @property 317 | def base_dist(self): 318 | return D.Normal(self.base_dist_mean, self.base_dist_var) 319 | 320 | def forward(self, x, y=None): 321 | # MAF eq 4 -- return mean and log std 322 | m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=1) 323 | u = (x - m) * torch.exp(-loga) 324 | # MAF eq 5 325 | log_abs_det_jacobian = - loga 326 | return u, log_abs_det_jacobian 327 | 328 | def log_prob(self, x, y=None): 329 | u, log_abs_det_jacobian = self.forward(x, y) 330 | return torch.sum(self.base_dist.log_prob(u) + log_abs_det_jacobian, dim=1) 331 | 332 | 333 | class MAF(nn.Module): 334 | def __init__(self, n_blocks, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', batch_norm=True): 335 | super().__init__() 336 | # base distribution for calculation of log prob under the model 337 | self.register_buffer('base_dist_mean', torch.zeros(input_size)) 338 | self.register_buffer('base_dist_var', torch.ones(input_size)) 339 | 340 | # construct model 341 | modules = [] 342 | self.input_degrees = None 343 | for i in range(n_blocks): 344 | modules += [MADE(input_size, hidden_size, n_hidden, cond_label_size, activation, input_order, self.input_degrees)] 345 | self.input_degrees = modules[-1].input_degrees.flip(0) 346 | modules += batch_norm * [BatchNorm(input_size)] 347 | 348 | self.net = FlowSequential(*modules) 349 | 350 | @property 351 | def base_dist(self): 352 | return D.Normal(self.base_dist_mean, self.base_dist_var) 353 | 354 | def forward(self, x, y=None): 355 | return self.net(x, y) 356 | 357 | def inverse(self, u, y=None): 358 | return self.net.inverse(u, y) 359 | 360 | def log_prob(self, x, y=None): 361 | u, sum_log_abs_det_jacobians = self.forward(x, y) 362 | return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=1) 363 | 364 | 365 | class MAF_Full(nn.Module): 366 | def __init__(self, n_blocks, input_size, hidden_size, n_hidden, cond_label_size=None, activation='relu', input_order='sequential', batch_norm=True): 367 | super().__init__() 368 | # base distribution for calculation of log prob under the model 369 | self.register_buffer('base_dist_mean', torch.zeros(input_size)) 370 | self.register_buffer('base_dist_var', torch.ones(input_size)) 371 | 372 | # construct model 373 | modules = [] 374 | self.input_degrees = None 375 | for i in range(n_blocks): 376 | modules += [MADE_Full(input_size, hidden_size, n_hidden, cond_label_size, activation, input_order, self.input_degrees)] 377 | self.input_degrees = modules[-1].input_degrees.flip(0) 378 | modules += batch_norm * [BatchNorm(input_size)] 379 | 380 | self.net = FlowSequential(*modules) 381 | 382 | @property 383 | def base_dist(self): 384 | return D.Normal(self.base_dist_mean, self.base_dist_var) 385 | 386 | def forward(self, x, y=None): 387 | 
return self.net(x, y) 388 | 389 | def inverse(self, u, y=None): 390 | return self.net.inverse(u, y) 391 | 392 | def log_prob(self, x, y=None): 393 | u, sum_log_abs_det_jacobians = self.forward(x, y) 394 | return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=1) 395 | 396 | 397 | 398 | class RealNVP(nn.Module): 399 | def __init__(self, n_blocks, input_size, hidden_size, n_hidden, cond_label_size=None, batch_norm=True): 400 | super().__init__() 401 | 402 | # base distribution for calculation of log prob under the model 403 | self.register_buffer('base_dist_mean', torch.zeros(input_size)) 404 | self.register_buffer('base_dist_var', torch.ones(input_size)) 405 | 406 | # construct model 407 | modules = [] 408 | mask = torch.arange(input_size).float() % 2 409 | for i in range(n_blocks): 410 | modules += [LinearMaskedCoupling(input_size, hidden_size, n_hidden, mask, cond_label_size)] 411 | mask = 1 - mask 412 | modules += batch_norm * [BatchNorm(input_size)] 413 | 414 | self.net = FlowSequential(*modules) 415 | 416 | @property 417 | def base_dist(self): 418 | return D.Normal(self.base_dist_mean, self.base_dist_var) 419 | 420 | def forward(self, x, y=None): 421 | return self.net(x, y) 422 | 423 | def inverse(self, u, y=None): 424 | return self.net.inverse(u, y) 425 | 426 | def log_prob(self, x, y=None): 427 | u, sum_log_abs_det_jacobians = self.forward(x, y) 428 | return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=1) 429 | -------------------------------------------------------------------------------- /models/RNN.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import torch 3 | import torch.nn as nn 4 | from functools import partial 5 | 6 | class RecurrentEncoder(nn.Module): 7 | """Recurrent encoder""" 8 | 9 | def __init__(self, n_features, latent_dim, rnn): 10 | super().__init__() 11 | 12 | self.rec_enc1 = rnn(n_features, latent_dim, batch_first=True) 13 | 14 | def forward(self, x): 15 | _, h_n = self.rec_enc1(x) 16 | 17 | return h_n 18 | 19 | class RecurrentDecoder(nn.Module): 20 | """Recurrent decoder for RNN and GRU""" 21 | 22 | def __init__(self, latent_dim, n_features, rnn_cell, device): 23 | super().__init__() 24 | 25 | self.n_features = n_features 26 | self.device = device 27 | self.rec_dec1 = rnn_cell(n_features, latent_dim) 28 | self.dense_dec1 = nn.Linear(latent_dim, n_features) 29 | 30 | def forward(self, h_0, seq_len): 31 | # Initialize output 32 | x = torch.tensor([], device = self.device) 33 | 34 | # Squeezing 35 | h_i = h_0.squeeze() 36 | 37 | # Reconstruct first element with encoder output 38 | x_i = self.dense_dec1(h_i) 39 | 40 | # Reconstruct remaining elements 41 | for i in range(0, seq_len): 42 | h_i = self.rec_dec1(x_i, h_i) 43 | x_i = self.dense_dec1(h_i) 44 | x = torch.cat([x, x_i], axis=1) 45 | 46 | return x.view(-1, seq_len, self.n_features) 47 | 48 | 49 | class RecurrentDecoderLSTM(nn.Module): 50 | """Recurrent decoder LSTM""" 51 | 52 | def __init__(self, latent_dim, n_features, rnn_cell, device): 53 | super().__init__() 54 | 55 | self.n_features = n_features 56 | self.device = device 57 | self.rec_dec1 = rnn_cell(n_features, latent_dim) 58 | self.dense_dec1 = nn.Linear(latent_dim, n_features) 59 | 60 | def forward(self, h_0, seq_len): 61 | # Initialize output 62 | x = torch.tensor([], device = self.device) 63 | 64 | # Squeezing 65 | h_i = [h.squeeze() for h in h_0] 66 | 67 | # Reconstruct first element with encoder output 68 | x_i = self.dense_dec1(h_i[0]) 69 | 70 | # 
Reconstruct remaining elements 71 | for i in range(0, seq_len): 72 | h_i = self.rec_dec1(x_i, h_i) 73 | x_i = self.dense_dec1(h_i[0]) 74 | x = torch.cat([x, x_i], axis = 1) 75 | 76 | return x.view(-1, seq_len, self.n_features) 77 | 78 | 79 | class RecurrentAE(nn.Module): 80 | """Recurrent autoencoder""" 81 | 82 | def __init__(self, n_features, latent_dim, device): 83 | super().__init__() 84 | 85 | # Encoder and decoder argsuration 86 | self.rnn, self.rnn_cell = nn.LSTM, nn.LSTMCell 87 | self.decoder = RecurrentDecoderLSTM 88 | self.latent_dim = latent_dim 89 | self.n_features = n_features 90 | self.device = device 91 | 92 | # Encoder and decoder 93 | self.encoder = RecurrentEncoder(self.n_features, self.latent_dim, self.rnn) 94 | self.decoder = self.decoder(self.latent_dim, self.n_features, self.rnn_cell, self.device) 95 | 96 | def forward(self, x): 97 | # x: N X K X L X D 98 | seq_len = x.shape[1] 99 | h_n = self.encoder(x) 100 | out = self.decoder(h_n, seq_len) 101 | 102 | return torch.flip(out, [1]) 103 | 104 | # %% 105 | -------------------------------------------------------------------------------- /models/__pycache__/DROCC.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/DROCC.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/DeepSAD.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/DeepSAD.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/DeepSAD.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/DeepSAD.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/GAN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/GAN.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/GANF.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/GANF.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/GDN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/GDN.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/NF.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/NF.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/NF.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/NF.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/PMUNF.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/PMUNF.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/PMUNF.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/PMUNF.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/RNN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/RNN.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/RNN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/RNN.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/graph_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnyanDai/GANF/bce38333b109f325a403dae9aff4987ae5bd6e1f/models/__pycache__/graph_layer.cpython-36.pyc -------------------------------------------------------------------------------- /models/graph_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter, Linear, Sequential, BatchNorm1d, ReLU 3 | import torch.nn.functional as F 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax 6 | 7 | from torch_geometric.nn.inits import glorot, zeros 8 | 9 | class GraphLayer(MessagePassing): 10 | def __init__(self, in_channels, out_channels, heads=1, concat=True, 11 | negative_slope=0.2, dropout=0, bias=True, inter_dim=-1,**kwargs): 12 | super(GraphLayer, self).__init__(aggr='add', **kwargs) 13 | 14 | self.in_channels = in_channels 15 | self.out_channels = out_channels 16 | self.heads = heads 17 | self.concat = concat 18 | self.negative_slope = negative_slope 19 | self.dropout = dropout 20 | 21 | self.__alpha__ = None 22 | 23 | self.lin = Linear(in_channels, heads * out_channels, bias=False) 24 | 25 | self.att_i = Parameter(torch.Tensor(1, heads, out_channels)) 26 | self.att_j = Parameter(torch.Tensor(1, heads, out_channels)) 27 | self.att_em_i = Parameter(torch.Tensor(1, heads, out_channels)) 28 | self.att_em_j = Parameter(torch.Tensor(1, heads, out_channels)) 29 | 30 | if bias and concat: 31 | self.bias = Parameter(torch.Tensor(heads * out_channels)) 32 | elif bias and not concat: 33 | self.bias = Parameter(torch.Tensor(out_channels)) 34 | else: 35 | self.register_parameter('bias', None) 36 | 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self): 40 | glorot(self.lin.weight) 41 | 
glorot(self.att_i) 42 | glorot(self.att_j) 43 | 44 | zeros(self.att_em_i) 45 | zeros(self.att_em_j) 46 | 47 | zeros(self.bias) 48 | 49 | 50 | 51 | def forward(self, x, edge_index, embedding, return_attention_weights=False): 52 | """""" 53 | if torch.is_tensor(x): 54 | x = self.lin(x) 55 | x = (x, x) 56 | else: 57 | x = (self.lin(x[0]), self.lin(x[1])) 58 | 59 | edge_index, _ = remove_self_loops(edge_index) 60 | edge_index, _ = add_self_loops(edge_index, 61 | num_nodes=x[1].size(self.node_dim)) 62 | 63 | out = self.propagate(edge_index, x=x, embedding=embedding, edges=edge_index, 64 | return_attention_weights=return_attention_weights) 65 | 66 | if self.concat: 67 | out = out.view(-1, self.heads * self.out_channels) 68 | else: 69 | out = out.mean(dim=1) 70 | 71 | if self.bias is not None: 72 | out = out + self.bias 73 | 74 | if return_attention_weights: 75 | alpha, self.__alpha__ = self.__alpha__, None 76 | return out, (edge_index, alpha) 77 | else: 78 | return out 79 | 80 | def message(self, x_i, x_j, edge_index_i, size_i, 81 | embedding, 82 | edges, 83 | return_attention_weights): 84 | 85 | x_i = x_i.view(-1, self.heads, self.out_channels) 86 | x_j = x_j.view(-1, self.heads, self.out_channels) 87 | 88 | if embedding is not None: 89 | embedding_i, embedding_j = embedding[edge_index_i], embedding[edges[0]] 90 | embedding_i = embedding_i.unsqueeze(1).repeat(1,self.heads,1) 91 | embedding_j = embedding_j.unsqueeze(1).repeat(1,self.heads,1) 92 | 93 | key_i = torch.cat((x_i, embedding_i), dim=-1) 94 | key_j = torch.cat((x_j, embedding_j), dim=-1) 95 | 96 | 97 | 98 | cat_att_i = torch.cat((self.att_i, self.att_em_i), dim=-1) 99 | cat_att_j = torch.cat((self.att_j, self.att_em_j), dim=-1) 100 | 101 | alpha = (key_i * cat_att_i).sum(-1) + (key_j * cat_att_j).sum(-1) 102 | 103 | 104 | alpha = alpha.view(-1, self.heads, 1) 105 | 106 | 107 | alpha = F.leaky_relu(alpha, self.negative_slope) 108 | alpha = softmax(alpha, edge_index_i, num_nodes=size_i) 109 | 110 | if return_attention_weights: 111 | self.__alpha__ = alpha 112 | 113 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 114 | 115 | return x_j * alpha.view(-1, self.heads, 1) 116 | 117 | 118 | 119 | def __repr__(self): 120 | return '{}({}, {}, heads={})'.format(self.__class__.__name__, 121 | self.in_channels, 122 | self.out_channels, self.heads) 123 | -------------------------------------------------------------------------------- /train_traffic.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | import argparse 4 | import torch 5 | from models.GANF import GANF 6 | import numpy as np 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | # files 11 | parser.add_argument('--data_dir', type=str, 12 | default='./data', help='Location of datasets.') 13 | parser.add_argument('--output_dir', type=str, 14 | default='./checkpoint/model') 15 | parser.add_argument('--name',default='traffic') 16 | parser.add_argument('--dataset', type=str, default='metr-la') 17 | # restore 18 | parser.add_argument('--graph', type=str, default='None') 19 | parser.add_argument('--model', type=str, default='None') 20 | parser.add_argument('--seed', type=int, default=10, help='Random seed to use.') 21 | # model parameters 22 | parser.add_argument('--n_blocks', type=int, default=6, help='Number of blocks to stack in a model (MADE in MAF; Coupling+BN in RealNVP).') 23 | parser.add_argument('--n_components', type=int, default=1, help='Number of Gaussian clusters for mixture of gaussians models.') 24 | 
parser.add_argument('--hidden_size', type=int, default=32, help='Hidden layer size for MADE (and each MADE block in an MAF).') 25 | parser.add_argument('--n_hidden', type=int, default=1, help='Number of hidden layers in each MADE.') 26 | parser.add_argument('--batch_norm', type=bool, default=False) 27 | # training params 28 | parser.add_argument('--batch_size', type=int, default=64) 29 | parser.add_argument('--weight_decay', type=float, default=5e-4) 30 | parser.add_argument('--n_epochs', type=int, default=20) 31 | parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.') 32 | parser.add_argument('--log_interval', type=int, default=5, help='How often to show loss statistics and save samples.') 33 | 34 | parser.add_argument('--h_tol', type=float, default=1e-6) 35 | parser.add_argument('--rho_max', type=float, default=1e16) 36 | parser.add_argument('--max_iter', type=int, default=20) 37 | parser.add_argument('--lambda1', type=float, default=0.0) 38 | parser.add_argument('--rho_init', type=float, default=1.0) 39 | parser.add_argument('--alpha_init', type=float, default=0.0) 40 | 41 | args = parser.parse_known_args()[0] 42 | args.cuda = torch.cuda.is_available() 43 | device = torch.device("cuda" if args.cuda else "cpu") 44 | 45 | 46 | print(args) 47 | import random 48 | import numpy as np 49 | random.seed(args.seed) 50 | np.random.seed(args.seed) 51 | torch.manual_seed(args.seed) 52 | if args.cuda: 53 | torch.cuda.manual_seed(args.seed) 54 | #%% 55 | print("Loading dataset") 56 | from dataset import load_traffic 57 | 58 | train_loader, val_loader, test_loader, n_sensor = load_traffic("{}/{}.h5".format(args.data_dir,args.dataset), \ 59 | args.batch_size) 60 | #%% 61 | 62 | rho = args.rho_init 63 | alpha = args.alpha_init 64 | lambda1 = args.lambda1 65 | h_A_old = np.inf 66 | 67 | 68 | max_iter = args.max_iter 69 | rho_max = args.rho_max 70 | h_tol = args.h_tol 71 | epoch = 0 72 | 73 | # initialize A 74 | if args.graph != 'None': 75 | init = torch.load(args.graph).to(device).abs() 76 | print("Load graph from "+args.graph) 77 | else: 78 | from torch.nn.init import xavier_uniform_ 79 | init = torch.zeros([n_sensor, n_sensor]) 80 | init = xavier_uniform_(init).abs() 81 | init = init.fill_diagonal_(0.0) 82 | A = torch.tensor(init, requires_grad=True, device=device) 83 | 84 | #%% 85 | model = GANF(args.n_blocks, 1, args.hidden_size, args.n_hidden, dropout=0.0, batch_norm=args.batch_norm) 86 | model = model.to(device) 87 | 88 | if args.model != 'None': 89 | model.load_state_dict(torch.load(args.model)) 90 | print('Load model from '+args.model) 91 | #%% 92 | from torch.nn.utils import clip_grad_value_ 93 | save_path = os.path.join(args.output_dir,args.name) 94 | if not os.path.exists(save_path): 95 | os.makedirs(save_path) 96 | loss_best = 100 97 | for _ in range(max_iter): 98 | 99 | while rho < rho_max: 100 | lr = args.lr #* np.math.pow(0.1, epoch // 100) 101 | optimizer = torch.optim.Adam([ 102 | {'params':model.parameters(), 'weight_decay':args.weight_decay}, 103 | {'params': [A]}], lr=lr, weight_decay=0.0) 104 | # train 105 | 106 | for _ in range(args.n_epochs): 107 | 108 | # train 109 | loss_train = [] 110 | epoch += 1 111 | model.train() 112 | for x in train_loader: 113 | x = x.to(device) 114 | 115 | optimizer.zero_grad() 116 | A_hat = torch.divide(A.T,A.sum(dim=1).detach()).T 117 | loss = -model(x, A_hat) 118 | h = torch.trace(torch.matrix_exp(A_hat*A_hat)) - n_sensor 119 | total_loss = loss + 0.5 * rho * h * h + alpha * h 120 | 121 | total_loss.backward() 122 | 
clip_grad_value_(model.parameters(), 1) 123 | optimizer.step() 124 | loss_train.append(loss.item()) 125 | A.data.copy_(torch.clamp(A.data, min=0, max=1)) 126 | 127 | # evaluate 128 | model.eval() 129 | loss_val = [] 130 | with torch.no_grad(): 131 | for x in val_loader: 132 | 133 | x = x.to(device) 134 | loss = -model(x,A_hat.data) 135 | loss_val.append(loss.item()) 136 | 137 | print('Epoch: {}, train -log_prob: {:.2f}, test -log_prob: {:.2f}, h: {}'\ 138 | .format(epoch, np.mean(loss_train), np.mean(loss_val), h.item())) 139 | 140 | if np.mean(loss_val) < loss_best: 141 | loss_best = np.mean(loss_val) 142 | print("save model {} epoch".format(epoch)) 143 | torch.save(A.data,os.path.join(save_path, "graph_best.pt")) 144 | torch.save(model.state_dict(), os.path.join(save_path, "{}_best.pt".format(args.name))) 145 | 146 | 147 | print('rho: {}, alpha {}, h {}'.format(rho, alpha, h.item())) 148 | print('===========================================') 149 | torch.save(A.data,os.path.join(save_path, "graph_{}.pt".format(epoch))) 150 | torch.save(model.state_dict(), os.path.join(save_path, "{}_{}.pt".format(args.name, epoch))) 151 | 152 | del optimizer 153 | torch.cuda.empty_cache() 154 | 155 | if h.item() > 0.5 * h_A_old: 156 | rho *= 10 157 | else: 158 | break 159 | 160 | h_A_old = h.item() 161 | alpha += rho*h.item() 162 | 163 | if h_A_old <= h_tol or rho >=rho_max: 164 | break 165 | 166 | 167 | # %% 168 | lr = args.lr * 0.1 169 | optimizer = torch.optim.Adam([ 170 | {'params':model.parameters(), 'weight_decay':args.weight_decay}, 171 | {'params': [A]}], lr=lr, weight_decay=0.0) 172 | # train 173 | 174 | for _ in range(100): 175 | loss_train = [] 176 | epoch += 1 177 | model.train() 178 | for x in train_loader: 179 | x = x.to(device) 180 | 181 | optimizer.zero_grad() 182 | A_hat = torch.divide(A.T,A.sum(dim=1).detach()).T 183 | loss = -model(x, A_hat) 184 | h = torch.trace(torch.matrix_exp(A_hat*A_hat)) - n_sensor 185 | total_loss = loss + 0.5 * rho * h * h + alpha * h 186 | 187 | total_loss.backward() 188 | clip_grad_value_(model.parameters(), 1) 189 | optimizer.step() 190 | loss_train.append(loss.item()) 191 | A.data.copy_(torch.clamp(A.data, min=0, max=1)) 192 | 193 | model.eval() 194 | loss_val = [] 195 | print(A.max()) 196 | with torch.no_grad(): 197 | for x in val_loader: 198 | 199 | x = x.to(device) 200 | loss = -model(x,A_hat.data) 201 | loss_val.append(loss.item()) 202 | 203 | print('Epoch: {}, train -log_prob: {:.2f}, test -log_prob: {:.2f}, h: {}'\ 204 | .format(epoch, np.mean(loss_train), np.mean(loss_val), h.item())) 205 | 206 | if np.mean(loss_val) < loss_best: 207 | loss_best = np.mean(loss_val) 208 | print("save model {} epoch".format(epoch)) 209 | torch.save(A.data,os.path.join(save_path, "graph_best.pt")) 210 | torch.save(model.state_dict(), os.path.join(save_path, "{}_best.pt".format(args.name))) 211 | 212 | if epoch % args.log_interval==0: 213 | torch.save(A.data,os.path.join(save_path, "graph_{}.pt".format(epoch))) 214 | torch.save(model.state_dict(), os.path.join(save_path, "{}_{}.pt".format(args.name, epoch))) 215 | 216 | #%% 217 | -------------------------------------------------------------------------------- /train_water.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | import argparse 4 | import torch 5 | from models.GANF import GANF 6 | import numpy as np 7 | from sklearn.metrics import roc_auc_score 8 | # from data import fetch_dataloaders 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | # files 13 | 
parser.add_argument('--data_dir', type=str, 14 | default='./data/SWaT_Dataset_Attack_v0.csv', help='Location of datasets.') 15 | parser.add_argument('--output_dir', type=str, 16 | default='./checkpoint/model') 17 | parser.add_argument('--name',default='GANF_Water') 18 | # restore 19 | parser.add_argument('--graph', type=str, default='None') 20 | parser.add_argument('--model', type=str, default='None') 21 | parser.add_argument('--seed', type=int, default=18, help='Random seed to use.') 22 | # made parameters 23 | parser.add_argument('--n_blocks', type=int, default=1, help='Number of blocks to stack in a model (MADE in MAF; Coupling+BN in RealNVP).') 24 | parser.add_argument('--n_components', type=int, default=1, help='Number of Gaussian clusters for mixture of gaussians models.') 25 | parser.add_argument('--hidden_size', type=int, default=32, help='Hidden layer size for MADE (and each MADE block in an MAF).') 26 | parser.add_argument('--n_hidden', type=int, default=1, help='Number of hidden layers in each MADE.') 27 | parser.add_argument('--batch_norm', type=bool, default=False) 28 | # training params 29 | parser.add_argument('--batch_size', type=int, default=512) 30 | parser.add_argument('--weight_decay', type=float, default=5e-4) 31 | parser.add_argument('--n_epochs', type=int, default=1) 32 | parser.add_argument('--lr', type=float, default=2e-3, help='Learning rate.') 33 | parser.add_argument('--log_interval', type=int, default=5, help='How often to show loss statistics and save samples.') 34 | 35 | parser.add_argument('--h_tol', type=float, default=1e-4) 36 | parser.add_argument('--rho_max', type=float, default=1e16) 37 | parser.add_argument('--max_iter', type=int, default=20) 38 | parser.add_argument('--lambda1', type=float, default=0.0) 39 | parser.add_argument('--rho_init', type=float, default=1.0) 40 | parser.add_argument('--alpha_init', type=float, default=0.0) 41 | 42 | args = parser.parse_known_args()[0] 43 | args.cuda = torch.cuda.is_available() 44 | device = torch.device("cuda" if args.cuda else "cpu") 45 | 46 | 47 | print(args) 48 | import random 49 | import numpy as np 50 | random.seed(args.seed) 51 | np.random.seed(args.seed) 52 | torch.manual_seed(args.seed) 53 | if args.cuda: 54 | torch.cuda.manual_seed(args.seed) 55 | #%% 56 | print("Loading dataset") 57 | from dataset import load_water 58 | 59 | train_loader, val_loader, test_loader, n_sensor = load_water(args.data_dir, \ 60 | args.batch_size) 61 | #%% 62 | 63 | rho = args.rho_init 64 | alpha = args.alpha_init 65 | lambda1 = args.lambda1 66 | h_A_old = np.inf 67 | 68 | 69 | max_iter = args.max_iter 70 | rho_max = args.rho_max 71 | h_tol = args.h_tol 72 | epoch = 0 73 | 74 | # initialize A 75 | if args.graph != 'None': 76 | init = torch.load(args.graph).to(device).abs() 77 | print("Load graph from "+args.graph) 78 | else: 79 | from torch.nn.init import xavier_uniform_ 80 | init = torch.zeros([n_sensor, n_sensor]) 81 | init = xavier_uniform_(init).abs() 82 | init = init.fill_diagonal_(0.0) 83 | A = torch.tensor(init, requires_grad=True, device=device) 84 | 85 | #%% 86 | model = GANF(args.n_blocks, 1, args.hidden_size, args.n_hidden, dropout=0.0, batch_norm=args.batch_norm) 87 | model = model.to(device) 88 | 89 | if args.model != 'None': 90 | model.load_state_dict(torch.load(args.model)) 91 | print('Load model from '+args.model) 92 | #%% 93 | from torch.nn.utils import clip_grad_value_ 94 | import seaborn as sns 95 | import matplotlib.pyplot as plt 96 | save_path = os.path.join(args.output_dir,args.name) 97 | if not 
os.path.exists(save_path): 98 | os.makedirs(save_path) 99 | 100 | 101 | loss_best = 100 102 | 103 | for _ in range(max_iter): 104 | 105 | while rho < rho_max: 106 | lr = args.lr 107 | optimizer = torch.optim.Adam([ 108 | {'params':model.parameters(), 'weight_decay':args.weight_decay}, 109 | {'params': [A]}], lr=lr, weight_decay=0.0) 110 | 111 | for _ in range(args.n_epochs): 112 | 113 | # train iteration 114 | loss_train = [] 115 | epoch += 1 116 | model.train() 117 | for x in train_loader: 118 | x = x.to(device) 119 | 120 | optimizer.zero_grad() 121 | loss = -model(x, A) 122 | h = torch.trace(torch.matrix_exp( A* A)) - n_sensor 123 | total_loss = loss + 0.5 * rho * h * h + alpha * h 124 | 125 | total_loss.backward() 126 | clip_grad_value_(model.parameters(), 1) 127 | optimizer.step() 128 | loss_train.append(loss.item()) 129 | A.data.copy_(torch.clamp(A.data, min=0, max=1)) 130 | 131 | 132 | # evaluate iteration 133 | model.eval() 134 | loss_val = [] 135 | with torch.no_grad(): 136 | for x in val_loader: 137 | 138 | x = x.to(device) 139 | loss = -model.test(x, A.data).cpu().numpy() 140 | loss_val.append(loss) 141 | loss_val = np.concatenate(loss_val) 142 | 143 | loss_test = [] 144 | with torch.no_grad(): 145 | for x in test_loader: 146 | 147 | x = x.to(device) 148 | loss = -model.test(x, A.data).cpu().numpy() 149 | loss_test.append(loss) 150 | loss_test = np.concatenate(loss_test) 151 | 152 | print(loss_val.max(), loss_val.min(), loss_test.max(), loss_test.min()) 153 | 154 | loss_val = np.nan_to_num(loss_val) 155 | loss_test = np.nan_to_num(loss_test) 156 | roc_val = roc_auc_score(np.asarray(val_loader.dataset.label.values,dtype=int),loss_val) 157 | roc_test = roc_auc_score(np.asarray(test_loader.dataset.label.values,dtype=int),loss_test) 158 | print('Epoch: {}, train -log_prob: {:.2f}, test -log_prob: {:.2f}, roc_val: {:.4f}, roc_test: {:.4f} ,h: {}'\ 159 | .format(epoch, np.mean(loss_train), np.mean(loss_val), roc_val, roc_test, h.item())) 160 | 161 | print('rho: {}, alpha {}, h {}'.format(rho, alpha, h.item())) 162 | print('===========================================') 163 | torch.save(A.data,os.path.join(save_path, "graph_{}.pt".format(epoch))) 164 | torch.save(model.state_dict(), os.path.join(save_path, "{}_{}.pt".format(args.name, epoch))) 165 | 166 | del optimizer 167 | torch.cuda.empty_cache() 168 | 169 | if h.item() > 0.5 * h_A_old: 170 | rho *= 10 171 | else: 172 | break 173 | 174 | 175 | h_A_old = h.item() 176 | alpha += rho*h.item() 177 | 178 | if h_A_old <= h_tol or rho >=rho_max: 179 | break 180 | 181 | 182 | # %% 183 | lr = args.lr 184 | optimizer = torch.optim.Adam([ 185 | {'params':model.parameters(), 'weight_decay':args.weight_decay}, 186 | {'params': [A]}], lr=lr, weight_decay=0.0) 187 | 188 | for _ in range(30): 189 | loss_train = [] 190 | epoch += 1 191 | model.train() 192 | for x in train_loader: 193 | x = x.to(device) 194 | 195 | optimizer.zero_grad() 196 | loss = -model(x, A) 197 | h = torch.trace(torch.matrix_exp(A*A)) - n_sensor 198 | total_loss = loss + 0.5 * rho * h * h + alpha * h 199 | 200 | total_loss.backward() 201 | clip_grad_value_(model.parameters(), 1) 202 | optimizer.step() 203 | loss_train.append(loss.item()) 204 | A.data.copy_(torch.clamp(A.data, min=0, max=1)) 205 | 206 | # eval 207 | model.eval() 208 | loss_val = [] 209 | with torch.no_grad(): 210 | for x in val_loader: 211 | 212 | x = x.to(device) 213 | loss = -model.test(x, A.data).cpu().numpy() 214 | loss_val.append(loss) 215 | loss_val = np.concatenate(loss_val) 216 | 217 | loss_test = [] 218 | 
with torch.no_grad(): 219 | for x in test_loader: 220 | 221 | x = x.to(device) 222 | loss = -model.test(x, A.data).cpu().numpy() 223 | loss_test.append(loss) 224 | loss_test = np.concatenate(loss_test) 225 | 226 | loss_val = np.nan_to_num(loss_val) 227 | loss_test = np.nan_to_num(loss_test) 228 | roc_val = roc_auc_score(np.asarray(val_loader.dataset.label.values,dtype=int),loss_val) 229 | roc_test = roc_auc_score(np.asarray(test_loader.dataset.label.values,dtype=int),loss_test) 230 | print('Epoch: {}, train -log_prob: {:.2f}, test -log_prob: {:.2f}, roc_val: {:.4f}, roc_test: {:.4f} ,h: {}'\ 231 | .format(epoch, np.mean(loss_train), np.mean(loss_val), roc_val, roc_test, h.item())) 232 | 233 | if np.mean(loss_val) < loss_best: 234 | loss_best = np.mean(loss_val) 235 | print("save model {} epoch".format(epoch)) 236 | torch.save(A.data,os.path.join(save_path, "graph_best.pt")) 237 | torch.save(model.state_dict(), os.path.join(save_path, "{}_best.pt".format(args.name))) 238 | 239 | if epoch % args.log_interval==0: 240 | torch.save(A.data,os.path.join(save_path, "graph_{}.pt".format(epoch))) 241 | torch.save(model.state_dict(), os.path.join(save_path, "{}_{}.pt".format(args.name, epoch))) -------------------------------------------------------------------------------- /train_water.sh: -------------------------------------------------------------------------------- 1 | for seed in {18..20} 2 | do 3 | python -u train_water.py\ 4 | --seed=${seed}\ 5 | --name=GANF_water_seed_${seed} 6 | done 7 | 8 | 9 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import torch 3 | 4 | def h(A): 5 | return torch.trace(torch.matrix_exp(A*A)) - A.shape[0] 6 | 7 | def normalize(A): 8 | D = A.sum(dim=0) 9 | D_inv = D.pow_(-1) 10 | D_inv.masked_fill_(D_inv == float('inf'), 0) 11 | 12 | return A * D_inv 13 | 14 | def thresholding(A, thre): 15 | return torch.where(A.abs()>thre, A, torch.scalar_tensor(0.0, dtype=torch.float32, device=A.device)) 16 | 17 | def binarize(A, thre): 18 | return torch.where(A.abs()>thre, 1.0, 0.0) 19 | # %% 20 | import pandas as pd 21 | def get_timestamp(stamps): 22 | return (stamps - pd.Timestamp("1970-01-01")) // pd.Timedelta("1s") 23 | # %% 24 | import numpy as np 25 | from sklearn.metrics import auc 26 | def roc_auc(label_time, pred, negative_sample, sigma): 27 | negative_sample = np.sort(negative_sample)[::-1] 28 | thresholds = list(negative_sample[::int(len(negative_sample)/50)]) 29 | thresholds.append(negative_sample[-1]) 30 | tps=[] 31 | fps=[] 32 | 33 | for thre in thresholds: 34 | pred_pos = pred[pred>thre] 35 | 36 | tp = 0 37 | for i in range(len(label_time)): 38 | start_time = label_time[i] - pd.Timedelta(30, unit='min') 39 | end_time = label_time[i] + pd.Timedelta(30, unit='min') 40 | 41 | detected_event = pred_pos[str(start_time): str(end_time)] 42 | if len(detected_event)>0: 43 | timestamps = get_timestamp(detected_event.index) 44 | delta_t = np.min(np.abs(timestamps.values - get_timestamp(label_time[i]))) 45 | tp += np.exp(-np.power(delta_t/sigma,2)) 46 | tp = tp/len(label_time) 47 | tps.append(tp) 48 | 49 | fp = (negative_sample>thre).sum()/len(negative_sample) 50 | fps.append(fp) 51 | return auc(fps,tps), (fps,tps) 52 | # %% 53 | def roc_auc_all(loss_np, delta_t, sigma): 54 | 55 | ground_truth = np.exp(-np.power((delta_t.values)/sigma,2)) 56 | 57 | loss_sort = np.sort(loss_np)[::-1] 58 | thresholds = list(loss_sort[::int(len(loss_sort)/50)]) 
59 | thresholds.append(loss_sort[-1]) 60 | 61 | n_pos = ground_truth.sum() 62 | n_neg = (1-ground_truth).sum() 63 | tps = [] 64 | fps = [] 65 | for thre in thresholds: 66 | pred_pos = loss_np>thre 67 | 68 | tp = ground_truth[pred_pos].sum()/n_pos 69 | fp = (1-ground_truth[pred_pos]).sum()/n_neg 70 | tps.append(tp) 71 | fps.append(fp) 72 | 73 | auc_score = auc(fps, tps) 74 | return auc_score, fps, tps --------------------------------------------------------------------------------
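A quick illustration of the graph utilities above: `h(A)` in `utils.py` is the NOTEARS-style acyclicity penalty trace(exp(A∘A)) − K, which vanishes exactly when `A` encodes a DAG, while `normalize` rescales each column of `A` by its column sum and `thresholding` zeroes out small entries. The snippet below is a minimal sketch, not part of the repository; the sensor count and threshold are arbitrary illustration values.

```python
import torch
from utils import h, normalize, thresholding

n_sensor = 5                                    # toy sensor count (illustration only)
A = torch.rand(n_sensor, n_sensor)              # random non-negative adjacency
A.fill_diagonal_(0.0)                           # no self-loops, as in the training scripts

print("acyclicity penalty h(A):", h(A).item())  # trace(matrix_exp(A*A)) - n_sensor
A_norm = normalize(A)                           # divide each column by its column sum
A_sparse = thresholding(A_norm, 0.1)            # keep only entries with |a_ij| > 0.1
print(A_sparse)
```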