├── temporal_model └── emb_128.pth ├── requirement.txt ├── utils.py ├── logging_set.py ├── LICENSE ├── config.py ├── temporal.py ├── README.md ├── preprocess ├── preprocess_utils.py ├── preprocess_cd.py └── preprocess_porto.py ├── train_labels.py ├── train.py ├── generate_outliers.py ├── mst_oatd.py ├── train_update.py └── mst_oatd_trainer.py /temporal_model/emb_128.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chwang0721/MST-OATD/HEAD/temporal_model/emb_128.pth -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.4 2 | pandas==2.1.4 3 | scipy==1.8.1 4 | scikit-learn==1.3.2 5 | networkx==3.2.1 6 | geopy==2.4.1 7 | torch==2.1.0 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.metrics import average_precision_score 3 | 4 | 5 | def make_len_mask(inp): 6 | return inp == 0 7 | 8 | 9 | def make_mask(mask): 10 | return (~mask).detach().type(torch.uint8) 11 | 12 | 13 | def auc_score(y_true, y_score): 14 | # precision, recall, _ = precision_recall_curve(y_true, y_score) 15 | # return auc(recall, precision) 16 | return average_precision_score(y_true, y_score) 17 | 18 | -------------------------------------------------------------------------------- /logging_set.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def get_logger(filename, verbosity=1, name=None): 5 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 6 | formatter = logging.Formatter( 7 | "[%(asctime)s][%(levelname)s] %(message)s" 8 | ) 9 | logger = logging.getLogger(name) 10 | logger.setLevel(level_dict[verbosity]) 11 | 12 | fh = logging.FileHandler(filename, "a") 13 | fh.setFormatter(formatter) 14 | logger.addHandler(fh) 15 | 16 | sh = logging.StreamHandler() 17 | sh.setFormatter(formatter) 18 | logger.addHandler(sh) 19 | 20 | return logger 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Chenhao Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument("--batch_size", type=int, default=1600) 5 | 6 | parser.add_argument('--embedding_size', type=int, default=128) 7 | parser.add_argument('--hidden_size', type=int, default=512) 8 | parser.add_argument('--n_cluster', type=int, default=20) 9 | 10 | parser.add_argument('--pretrain_lr_s', type=float, default=2e-3) 11 | parser.add_argument('--pretrain_lr_t', type=float, default=2e-3) 12 | 13 | parser.add_argument('--lr_s', type=float, default=2e-4) 14 | parser.add_argument('--lr_t', type=float, default=8e-5) 15 | 16 | parser.add_argument('--epochs', type=int, default=5) 17 | parser.add_argument('--pretrain_epochs', type=int, default=6) 18 | 19 | parser.add_argument("--ratio", type=float, default=0.05, help="ratio of outliers") 20 | parser.add_argument("--distance", type=int, default=2) 21 | parser.add_argument("--fraction", type=float, default=0.2) 22 | parser.add_argument("--obeserved_ratio", type=float, default=1.0) 23 | 24 | parser.add_argument("--device", type=str, default='cuda:0') 25 | parser.add_argument("--dataset", type=str, default='porto') 26 | parser.add_argument("--update_mode", type=str, default='pretrain') 27 | 28 | parser.add_argument("--train_num", type=int, default=80000) # 80000 200000 29 | 30 | parser.add_argument("--s1_size", type=int, default=2) 31 | parser.add_argument("--s2_size", type=int, default=4) 32 | 33 | parser.add_argument("--task", type=str, default='train') 34 | 35 | args = parser.parse_args() 36 | 37 | # python train.py --dataset porto --batch_size 1600 --pretrain_epochs 6 38 | # python train.py --dataset cd --batch_size 300 --pretrain_epochs 3 --epochs 4 39 | -------------------------------------------------------------------------------- /temporal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class TemporalEmbedding: 5 | def __init__(self, device, model_path="./temporal_model/emb_128.pth"): 6 | self.model = TemporalVec().to(device) 7 | self.model_path = model_path 8 | self.device = device 9 | 10 | def __call__(self, x): 11 | self.model.load_state_dict(torch.load(self.model_path, map_location=self.device)) 12 | with torch.no_grad(): 13 | return self.model.encode(torch.Tensor(x).unsqueeze(0)).squeeze(0) 14 | 15 | class TemporalVec(nn.Module): 16 | def __init__(self, k=128, act="sin"): 17 | super(TemporalVec, self).__init__() 18 | 19 | if k % 2 == 0: 20 | k1 = k // 2 21 | k2 = k // 2 22 | else: 23 | k1 = k // 2 24 | k2 = k // 2 + 1 25 | 26 | self.fc1 = nn.Linear(6, k1) 27 | self.fc2 = nn.Linear(6, k2) 28 | self.d2 = nn.Dropout(0.3) 29 | 30 | if act == 'sin': 31 | self.activation = torch.sin 32 | else: 33 | self.activation = torch.cos 34 | 35 | self.fc3 = nn.Linear(k, k // 2) 36 | self.d3 = nn.Dropout(0.3) 37 | self.fc4 = nn.Linear(k // 2, 6) 38 | self.fc5 = torch.nn.Linear(6, 6) 39 | 40 | def forward(self, x): 41 | out1 = self.fc1(x) 42 | out2 = self.d2(self.activation(self.fc2(x))) 43 | out = torch.cat([out1, out2], 1) 44 | out = self.d3(self.fc3(out)) 45 | out = self.fc4(out) 46 | out = self.fc5(out) 47 | return out 48 | 49 | def encode(self, x): 50 | out1 = self.fc1(x) 51 | out2 = self.activation(self.fc2(x)) 52 | out = torch.cat([out1, out2], -1) 53 | return out 54 | 
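The time vectors consumed by TemporalEmbedding follow the [hour, minute, second, year, month, day] convention produced by the preprocessing scripts, and TemporalVec.encode concatenates a plain linear projection with a sine-activated projection into a 128-dimensional embedding. Below is a minimal usage sketch, assuming the shipped checkpoint ./temporal_model/emb_128.pth is present; the sample timestamps are made up for illustration.

```
# Minimal sketch: embed two timestamps with the pretrained temporal model.
# Assumes ./temporal_model/emb_128.pth exists (shipped in temporal_model/).
from temporal import TemporalEmbedding

t2v = TemporalEmbedding(device="cpu")

# Time vectors use the preprocessing convention: [hour, minute, second, year, month, day].
times = [[8, 30, 0, 2013, 7, 1],
         [8, 30, 15, 2013, 7, 1]]

emb = t2v(times)    # fc1 output (64 dims) concatenated with sin(fc2) output (64 dims)
print(emb.shape)    # torch.Size([2, 128])
```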
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # MST-OATD
2 | Code for "Multi-Scale Detection of Anomalous Spatio-Temporal Trajectories in Evolving Trajectory Datasets"
3 | ### Requirements
4 | ```
5 | pip install -r requirement.txt
6 | ```
7 | ### Preprocessing
8 | - Step1: Download the Porto dataset (train.csv.zip) from [Porto](https://www.kaggle.com/c/pkdd-15-predict-taxi-service-trajectory-i/data), and the Chengdu dataset (Chengdu.zip) from [Chengdu](https://www.dropbox.com/scl/fi/w4jylj9het6x93btxud6o/Chengdu.zip?rlkey=w6x00pzyjk4z7fvxwhkryeq1l&dl=0).
9 | - Step2: Put the Porto data file in ../datasets/porto/ and unzip it as porto.csv. Put the unzipped Chengdu data in ../datasets/chengdu/.
10 | - Step3: Run preprocessing by
11 | ```
12 | mkdir -p data/<dataset>
13 | cd preprocess
14 | python preprocess_<dataset>.py
15 | cd ..
16 | mkdir logs models probs
17 | ```
18 | `<dataset>`: porto or cd
19 | 
20 | ### Generating ground truth
21 | 
22 | ```
23 | python generate_outliers.py --distance 2 --fraction 0.2 --obeserved_ratio 1.0 --dataset <dataset>
24 | ```
25 | distance controls how far outlier points are shifted, fraction is the fraction of consecutive points in a trajectory that are perturbed, and obeserved_ratio is the ratio of the trajectory that is observed.
26 | ### Training and testing
27 | ```
28 | python train.py --task train --dataset <dataset>
29 | python train.py --task test --distance 2 --fraction 0.2 --obeserved_ratio 1.0 --dataset <dataset>
30 | ```
31 | ### Training on evolving datasets
32 | ```
33 | python train_labels.py --dataset <dataset>
34 | python train_update.py --update_mode pretrain --dataset <dataset> --train_num <train_num>
35 | ```
36 | update_mode has three modes: pretrain, temporal, and rank; `<train_num>` is the number of trajectories used for evolving training.
37 | 
38 | ### Citation
39 | Please kindly cite our work if you find our paper or codes helpful. 
40 | 41 | link: https://dl.acm.org/doi/10.1145/3637528.3671874 42 | ``` 43 | @inproceedings{wang2024multi, 44 | title={Multi-Scale Detection of Anomalous Spatio-Temporal Trajectories in Evolving Trajectory Datasets}, 45 | author={Wang, Chenhao and Chen, Lisi and Shang, Shuo and Jensen, Christian S and Kalnis, Panos}, 46 | booktitle={Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, 47 | pages={2980--2990}, 48 | year={2024} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /preprocess/preprocess_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | 4 | import networkx as nx 5 | import numpy as np 6 | import scipy.sparse as sparse 7 | from geopy import distance 8 | 9 | 10 | # Determine whether a point is in boundary 11 | def in_boundary(lat, lng, b): 12 | return b['min_lng'] < lng < b['max_lng'] and b['min_lat'] < lat < b['max_lat'] 13 | 14 | 15 | # Cut long trajectories 16 | def cutting_trajs(traj, longest, shortest): 17 | cutted_trajs = [] 18 | while len(traj) > longest: 19 | random_length = np.random.randint(shortest, longest) 20 | cutted_traj = traj[:random_length] 21 | cutted_trajs.append(cutted_traj) 22 | traj = traj[random_length:] 23 | return cutted_trajs 24 | 25 | 26 | # convert datetime to time vector 27 | def convert_date(str): 28 | timeArray = time.strptime(str, "%Y/%m/%d %H:%M:%S") 29 | t = [timeArray.tm_hour, timeArray.tm_min, timeArray.tm_sec, timeArray.tm_year, timeArray.tm_mon, timeArray.tm_mday] 30 | return t 31 | 32 | 33 | # Calculate timestamp gap 34 | def timestamp_gap(str1, str2): 35 | timestamp1 = datetime.datetime.strptime(str1, "%Y/%m/%d %H:%M:%S") 36 | timestamp2 = datetime.datetime.strptime(str2, "%Y/%m/%d %H:%M:%S") 37 | return (timestamp2 - timestamp1).total_seconds() 38 | 39 | 40 | # Map trajectories to grids 41 | def grid_mapping(boundary, grid_size): 42 | lat_dist = distance.distance((boundary['min_lat'], boundary['min_lng']), 43 | (boundary['max_lat'], boundary['min_lng'])).km 44 | lat_size = (boundary['max_lat'] - boundary['min_lat']) / lat_dist * grid_size 45 | 46 | lng_dist = distance.distance((boundary['min_lat'], boundary['min_lng']), 47 | (boundary['min_lat'], boundary['max_lng'])).km 48 | lng_size = (boundary['max_lng'] - boundary['min_lng']) / lng_dist * grid_size 49 | 50 | lat_grid_num = int(lat_dist / grid_size) + 1 51 | lng_grid_num = int(lng_dist / grid_size) + 1 52 | return lat_size, lng_size, lat_grid_num, lng_grid_num 53 | 54 | 55 | # Generate adjacency matrix and normalized degree matrix 56 | def generate_matrix(lat_grid_num, lng_grid_num): 57 | G = nx.grid_2d_graph(lat_grid_num, lng_grid_num, periodic=False) 58 | A = nx.adjacency_matrix(G) 59 | I = sparse.identity(lat_grid_num * lng_grid_num) 60 | D = np.diag(np.sum(A + I, axis=1)) 61 | D = 1 / (np.sqrt(D) + 1e-10) 62 | D[D == 1e10] = 0. 
63 | D = sparse.csr_matrix(D) 64 | return A + I, D 65 | -------------------------------------------------------------------------------- /train_labels.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from config import args 8 | 9 | 10 | class LinearRegression(torch.nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | self.linear = torch.nn.Linear(1, 1) 14 | self.activate = nn.Sigmoid() 15 | 16 | def forward(self, x): 17 | out = self.activate(self.linear(x)) 18 | return out 19 | 20 | 21 | class Linear_Model: 22 | def __init__(self): 23 | self.learning_rate = 5e-4 24 | self.epoches = 100000 25 | self.loss_function = torch.nn.MSELoss() 26 | self.create_model() 27 | 28 | def create_model(self): 29 | self.model = LinearRegression().to(args.device) 30 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) 31 | 32 | def train(self, num, data): 33 | 34 | x = data.unsqueeze(1) 35 | x = (x - torch.mean(x)) / torch.std(x) 36 | y = torch.arange(1, len(x) + 1).unsqueeze(1) / len(x) 37 | 38 | temp = 100 39 | for epoch in range(self.epoches): 40 | prediction = self.model(x.to(args.device)) 41 | loss = self.loss_function(prediction, y.to(args.device)) 42 | 43 | self.optimizer.zero_grad() 44 | loss.backward() 45 | self.optimizer.step() 46 | 47 | if (epoch + 1) % 10000 == 0: 48 | print("Epoch: {} loss: {}".format(epoch + 1, loss.item())) 49 | if (temp - loss) < 1e-7: 50 | break 51 | else: 52 | temp = loss 53 | torch.save(self.model.state_dict(), "probs/linear_{}_{}.pth".format(num, args.dataset)) 54 | 55 | def test(self, num, data): 56 | x = data.unsqueeze(1) 57 | x = (x - torch.mean(x)) / torch.std(x) 58 | 59 | self.model.load_state_dict(torch.load("probs/linear_{}_{}.pth".format(num, args.dataset))) 60 | self.model.eval() 61 | prediction = self.model(x).cpu() 62 | 63 | return (prediction * len(x)).to(dtype=torch.int32).squeeze(1).detach().numpy() 64 | 65 | 66 | if __name__ == '__main__': 67 | pool = Pool(processes=10) 68 | for num in range(args.n_cluster): 69 | data = np.load('probs/probs_{}_{}.npy'.format(num, args.dataset)) 70 | data = torch.Tensor(data) 71 | linear = Linear_Model() 72 | pool.apply_async(linear.train, (num, data)) 73 | pool.close() 74 | pool.join() 75 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | 5 | from config import args 6 | from mst_oatd_trainer import train_mst_oatd, MyDataset, seed_torch, collate_fn 7 | 8 | 9 | def main(): 10 | train_trajs = np.load('./data/{}/train_data_init.npy'.format(args.dataset), allow_pickle=True) 11 | test_trajs = np.load('./data/{}/outliers_data_init_{}_{}_{}.npy'.format(args.dataset, args.distance, args.fraction, 12 | args.obeserved_ratio), allow_pickle=True) 13 | outliers_idx = np.load("./data/{}/outliers_idx_init_{}_{}_{}.npy".format(args.dataset, args.distance, args.fraction, 14 | args.obeserved_ratio), allow_pickle=True) 15 | 16 | train_data = MyDataset(train_trajs) 17 | test_data = MyDataset(test_trajs) 18 | 19 | labels = np.zeros(len(test_trajs)) 20 | for i in outliers_idx: 21 | labels[i] = 1 22 | labels = labels 23 | 24 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, 25 | 
num_workers=8, pin_memory=True) 26 | outliers_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, 27 | num_workers=8, pin_memory=True) 28 | 29 | MST_OATD = train_mst_oatd(s_token_size, t_token_size, labels, train_loader, outliers_loader, args) 30 | 31 | if args.task == 'train': 32 | 33 | MST_OATD.logger.info("Start pretraining!") 34 | 35 | for epoch in range(args.pretrain_epochs): 36 | MST_OATD.pretrain(epoch) 37 | 38 | MST_OATD.train_gmm() 39 | MST_OATD.save_weights_for_MSTOATD() 40 | 41 | MST_OATD.logger.info("Start training!") 42 | MST_OATD.load_mst_oatd() 43 | for epoch in range(args.epochs): 44 | MST_OATD.train(epoch) 45 | 46 | if args.task == 'test': 47 | 48 | MST_OATD.logger.info('Start testing!') 49 | MST_OATD.logger.info("d = {}".format(args.distance) + ", " + chr(945) + " = {}".format(args.fraction) + ", " 50 | + chr(961) + " = {}".format(args.obeserved_ratio)) 51 | 52 | checkpoint = torch.load(MST_OATD.path_checkpoint) 53 | MST_OATD.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s']) 54 | MST_OATD.MST_OATD_T.load_state_dict(checkpoint['model_state_dict_t']) 55 | pr_auc = MST_OATD.detection() 56 | pr_auc = "%.4f" % pr_auc 57 | MST_OATD.logger.info("PR_AUC: {}".format(pr_auc)) 58 | 59 | if args.task == 'train': 60 | MST_OATD.train_gmm_update() 61 | z = MST_OATD.get_hidden() 62 | MST_OATD.get_prob(z.cpu()) 63 | 64 | 65 | if __name__ == "__main__": 66 | 67 | if args.dataset == 'porto': 68 | s_token_size = 51 * 119 69 | t_token_size = 5760 70 | 71 | elif args.dataset == 'cd': 72 | s_token_size = 167 * 154 73 | t_token_size = 8640 74 | 75 | main() 76 | -------------------------------------------------------------------------------- /preprocess/preprocess_cd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from multiprocessing import Pool, Manager 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import scipy.sparse as sparse 8 | from sklearn.model_selection import train_test_split 9 | 10 | from preprocess_utils import * 11 | 12 | 13 | def preprocess(file, traj_path, shortest, longest, boundary, lat_size, lng_size, lng_grid_num, convert_date, 14 | timestamp_gap, in_boundary, cutting_trajs, traj_nums, point_nums): 15 | np.random.seed(1234) 16 | # Read and sort trajectories based on id and timestamp 17 | data = pd.read_csv("{}/{}".format(traj_path, file), header=None) 18 | data.columns = ['id', 'lat', 'lon', 'state', 'timestamp'] 19 | data = data.sort_values(by=['id', 'timestamp']) 20 | data = data[data['state'] == 1] 21 | 22 | trajs = [] 23 | traj_seq = [] 24 | valid = True 25 | 26 | pre_point = data.iloc[0] 27 | 28 | # Select trajectories 29 | for point in data.itertuples(): 30 | 31 | if point.id == pre_point.id and timestamp_gap(pre_point.timestamp, point.timestamp) <= 20: 32 | if in_boundary(point.lat, point.lon, boundary): 33 | grid_i = int((point.lat - boundary['min_lat']) / lat_size) 34 | grid_j = int((point.lon - boundary['min_lng']) / lng_size) 35 | traj_seq.append([grid_i * lng_grid_num + grid_j, convert_date(point[5])]) 36 | else: 37 | valid = False 38 | 39 | else: 40 | if valid: 41 | if shortest <= len(traj_seq) <= longest: 42 | trajs.append(traj_seq) 43 | elif len(traj_seq) > longest: 44 | trajs += cutting_trajs(traj_seq, longest, shortest) 45 | 46 | traj_seq = [] 47 | valid = True 48 | pre_point = point 49 | 50 | traj_nums.append(len(trajs)) 51 | point_nums.append(sum([len(traj) for traj in trajs])) 52 | 53 | train_data, 
test_data = train_test_split(trajs, test_size=0.2, random_state=42) 54 | np.save("../data/cd/train_data_{}.npy".format(file[:8]), np.array(train_data, dtype=object)) 55 | np.save("../data/cd/test_data_{}.npy".format(file[:8]), np.array(test_data, dtype=object)) 56 | 57 | 58 | # Parallel preprocess 59 | def batch_preprocess(path_list): 60 | manager = Manager() 61 | traj_nums = manager.list() 62 | point_nums = manager.list() 63 | pool = Pool(processes=10) 64 | pool.map(partial(preprocess, traj_path=traj_path, shortest=shortest, longest=longest, boundary=boundary, 65 | lat_size=lat_size, lng_size=lng_size, lng_grid_num=lng_grid_num, convert_date=convert_date, 66 | timestamp_gap=timestamp_gap, in_boundary=in_boundary, cutting_trajs=cutting_trajs, 67 | traj_nums=traj_nums, point_nums=point_nums), path_list) 68 | pool.close() 69 | pool.join() 70 | 71 | num_trajs = sum(traj_nums) 72 | num_points = sum(point_nums) 73 | print("Total trajectory num:", num_trajs) 74 | print("Total point num:", num_points) 75 | 76 | 77 | def merge(path_list): 78 | res_train = [] 79 | res_test = [] 80 | 81 | for file in path_list: 82 | 83 | train_trajs = np.load("../data/cd/train_data_{}.npy".format(file[:8]), allow_pickle=True) 84 | test_trajs = np.load("../data/cd/test_data_{}.npy".format(file[:8]), allow_pickle=True) 85 | res_train.append(train_trajs) 86 | res_test.append(test_trajs) 87 | 88 | res_train = np.concatenate(res_train, axis=0) 89 | res_test = np.concatenate(res_test, axis=0) 90 | 91 | return res_train, res_test 92 | 93 | 94 | def main(): 95 | path_list = os.listdir(traj_path) 96 | path_list.sort(key=lambda x: x.split('.')) 97 | path_list = path_list[:10] 98 | 99 | batch_preprocess(path_list) 100 | train_data, test_data = merge(path_list[:3]) 101 | 102 | np.save("../data/cd/train_data_init.npy", np.array(train_data, dtype=object)) 103 | np.save("../data/cd/test_data_init.npy", np.array(test_data, dtype=object)) 104 | 105 | print('Fnished!') 106 | 107 | 108 | if __name__ == "__main__": 109 | traj_path = "../../datasets/chengdu" 110 | 111 | grid_size = 0.1 112 | shortest, longest = 30, 100 113 | boundary = {'min_lat': 30.6, 'max_lat': 30.75, 'min_lng': 104, 'max_lng': 104.16} 114 | 115 | lat_size, lng_size, lat_grid_num, lng_grid_num = grid_mapping(boundary, grid_size) 116 | A, D = generate_matrix(lat_grid_num, lng_grid_num) 117 | 118 | sparse.save_npz('../data/cd/adj.npz', A) 119 | sparse.save_npz('../data/cd/d_norm.npz', D) 120 | 121 | print('Grid size:', (lat_grid_num, lng_grid_num)) 122 | print('----------Preprocessing----------') 123 | main() 124 | -------------------------------------------------------------------------------- /preprocess/preprocess_porto.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import random 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import scipy.sparse as sparse 8 | from sklearn.model_selection import train_test_split 9 | 10 | from preprocess_utils import * 11 | 12 | 13 | def time_convert(timestamp): 14 | return datetime.datetime.fromtimestamp(timestamp) 15 | 16 | 17 | def preprocess(trajectories, traj_num, point_num): 18 | trajs = [] # Preprocessed trajectories 19 | 20 | for traj in trajectories.itertuples(): 21 | 22 | traj_seq = [] 23 | valid = True # Flag to determine whether a trajectory is in boundary 24 | 25 | polyline = json.loads(traj.POLYLINE) 26 | timestamp = traj.TIMESTAMP 27 | 28 | if len(polyline) >= shortest: 29 | for lng, lat in polyline: 30 | 31 | if in_boundary(lat, lng, 
boundary): 32 | grid_i = int((lat - boundary['min_lat']) / lat_size) 33 | grid_j = int((lng - boundary['min_lng']) / lng_size) 34 | 35 | t = datetime.datetime.fromtimestamp(timestamp) 36 | t = [t.hour, t.minute, t.second, t.year, t.month, t.day] # Time vector 37 | 38 | traj_seq.append([int(grid_i * lng_grid_num + grid_j), t]) 39 | timestamp += 15 # In porto dataset, the sampling rate is 15 40 | 41 | else: 42 | valid = False 43 | break 44 | 45 | # Randomly delete 30% trajectory points to make the sampling rate not fixed 46 | to_delete = set(random.sample(range(len(traj_seq)), int(len(traj_seq) * 0.3))) 47 | traj_seq = [item for index, item in enumerate(traj_seq) if index not in to_delete] 48 | 49 | # Lengths are limited from 20 to 50 50 | if valid: 51 | if len(traj_seq) <= longest: 52 | trajs.append(traj_seq) 53 | else: 54 | trajs += cutting_trajs(traj_seq, longest, shortest) 55 | 56 | traj_num += len(trajs) 57 | 58 | for traj in trajs: 59 | point_num += len(traj) 60 | 61 | return trajs, traj_num, point_num 62 | 63 | 64 | def main(): 65 | # Read csv file 66 | trajectories = pd.read_csv("{}/{}.csv".format(data_dir, data_name), header=0, usecols=['POLYLINE', 'TIMESTAMP']) 67 | trajectories['datetime'] = trajectories['TIMESTAMP'].apply(time_convert) 68 | 69 | # Inititial dataset 70 | start_time = datetime.datetime(2013, 7, 1, 0, 0, 0) 71 | end_time = datetime.datetime(2013, 9, 1, 0, 0, 0) 72 | 73 | traj_num, point_num = 0, 0 74 | 75 | # Select trajectories from start time to end time 76 | trajs = trajectories[(trajectories['datetime'] >= start_time) & (trajectories['datetime'] < end_time)] 77 | preprocessed_trajs, traj_num, point_num = preprocess(trajs, traj_num, point_num) 78 | train_data, test_data = train_test_split(preprocessed_trajs, test_size=0.2, random_state=42) 79 | 80 | np.save("../data/porto/train_data_init.npy", np.array(train_data, dtype=object)) 81 | np.save("../data/porto/test_data_init.npy", np.array(test_data, dtype=object)) 82 | 83 | start_time = datetime.datetime(2013, 9, 1, 0, 0, 0) 84 | 85 | # Evolving dataset 86 | for month in range(1, 11): 87 | end_time = start_time + datetime.timedelta(days=30) 88 | trajs = trajectories[(trajectories['datetime'] >= start_time) & (trajectories['datetime'] < end_time)] 89 | 90 | preprocessed_trajs, traj_num, point_num = preprocess(trajs, traj_num, point_num) 91 | train_data, test_data = train_test_split(preprocessed_trajs, test_size=0.2, random_state=42) 92 | 93 | np.save("../data/porto/train_data_{}.npy".format(month), np.array(train_data, dtype=object)) 94 | np.save("../data/porto/test_data_{}.npy".format(month), np.array(test_data, dtype=object)) 95 | 96 | start_time = end_time 97 | 98 | # Dataset statistics 99 | print("Total trajectory num:", traj_num) 100 | print("Total point num:", point_num) 101 | 102 | print('Fnished!') 103 | 104 | 105 | if __name__ == '__main__': 106 | random.seed(1234) 107 | np.random.seed(1234) 108 | 109 | data_dir = '../../datasets/porto' 110 | data_name = "porto" 111 | 112 | boundary = {'min_lat': 41.140092, 'max_lat': 41.185969, 'min_lng': -8.690261, 'max_lng': -8.549155} 113 | grid_size = 0.1 114 | shortest, longest = 20, 50 115 | 116 | lat_size, lng_size, lat_grid_num, lng_grid_num = grid_mapping(boundary, grid_size) 117 | A, D = generate_matrix(lat_grid_num, lng_grid_num) 118 | 119 | sparse.save_npz('../data/porto/adj.npz', A) 120 | sparse.save_npz('../data/porto/d_norm.npz', D) 121 | 122 | print('Grid size:', (lat_grid_num, lng_grid_num)) 123 | print('----------Preprocessing----------') 124 | main() 125 
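Both preprocessing scripts encode each GPS point as a grid-cell token grid_i * lng_grid_num + grid_j plus a time vector; at training time the time vector is further collapsed into a second-of-day token divided by the sampling interval (15 s for Porto, 10 s for Chengdu), which is where the t_token_size values 5760 and 8640 in train.py come from. The sketch below illustrates that mapping, assuming it is run from the preprocess/ directory; the sample coordinate is arbitrary.

```
# Sketch of the point-to-token mapping used by the preprocessing scripts.
from preprocess_utils import grid_mapping, in_boundary

boundary = {'min_lat': 41.140092, 'max_lat': 41.185969, 'min_lng': -8.690261, 'max_lng': -8.549155}
lat_size, lng_size, lat_grid_num, lng_grid_num = grid_mapping(boundary, grid_size=0.1)

lat, lng = 41.16, -8.62                      # arbitrary point inside the Porto boundary
assert in_boundary(lat, lng, boundary)
grid_i = int((lat - boundary['min_lat']) / lat_size)
grid_j = int((lng - boundary['min_lng']) / lng_size)
spatial_token = grid_i * lng_grid_num + grid_j

hour, minute, second = 8, 30, 0
temporal_token = (second + minute * 60 + hour * 3600) // 15   # 86400 / 15 = 5760 slots per day
print(spatial_token, temporal_token)
```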
| -------------------------------------------------------------------------------- /generate_outliers.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | from datetime import timedelta 5 | 6 | import numpy as np 7 | 8 | from config import args 9 | 10 | 11 | # Trajectory location offset 12 | def perturb_point(point, level, offset=None): 13 | point, times = point[0], point[1] 14 | x, y = int(point // map_size[1]), int(point % map_size[1]) 15 | 16 | if offset is None: 17 | offset = [[0, 1], [1, 0], [-1, 0], [0, -1], [1, 1], [-1, -1], [-1, 1], [1, -1]] 18 | x_offset, y_offset = offset[np.random.randint(0, len(offset))] 19 | 20 | else: 21 | x_offset, y_offset = offset 22 | 23 | if 0 <= x + x_offset * level < map_size[0] and 0 <= y + y_offset * level < map_size[1]: 24 | x += x_offset * level 25 | y += y_offset * level 26 | 27 | return [int(x * map_size[1] + y), times] 28 | 29 | 30 | def convert(point): 31 | x, y = int(point // map_size[1]), int(point % map_size[1]) 32 | return [x, y] 33 | 34 | 35 | def distance(a, b): 36 | return math.sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) 37 | 38 | 39 | def time_calcuate(vec, s): 40 | a = datetime.datetime(vec[3], vec[4], vec[5], vec[0], vec[1], vec[2]) 41 | t = a + timedelta(seconds=s) 42 | return [t.hour, t.minute, t.second, t.year, t.month, t.day] 43 | 44 | 45 | # Trajectory time offset 46 | def perturb_time(traj, st_loc, end_loc, time_offset, interval): 47 | for i in range(st_loc, end_loc): 48 | traj[i][1] = time_calcuate(traj[i][1], int((i - st_loc + 1) * time_offset * interval)) 49 | 50 | for i in range(end_loc, len(traj)): 51 | traj[i][1] = time_calcuate(traj[i][1], int((end_loc - st_loc) * time_offset * interval)) 52 | return traj 53 | 54 | 55 | def perturb_batch(batch_x, level, prob, selected_idx): 56 | noisy_batch_x = [] 57 | 58 | if args.dataset == 'porto': 59 | interval = 15 60 | else: 61 | interval = 10 62 | 63 | for idx, traj in enumerate(batch_x): 64 | 65 | anomaly_len = int(len(traj) * prob) 66 | anomaly_st_loc = np.random.randint(1, len(traj) - anomaly_len - 1) 67 | 68 | if idx in selected_idx: 69 | anomaly_ed_loc = anomaly_st_loc + anomaly_len 70 | 71 | p_traj = traj[:anomaly_st_loc] + [perturb_point(p, level) for p in 72 | traj[anomaly_st_loc:anomaly_ed_loc]] + traj[anomaly_ed_loc:] 73 | 74 | dis = max(distance(convert(traj[anomaly_st_loc][0]), convert(traj[anomaly_ed_loc][0])), 1) 75 | time_offset = (level * 2) / dis 76 | 77 | p_traj = perturb_time(p_traj, anomaly_st_loc, anomaly_ed_loc, time_offset, interval) 78 | 79 | else: 80 | p_traj = traj 81 | 82 | p_traj = p_traj[:int(len(p_traj) * args.obeserved_ratio)] 83 | noisy_batch_x.append(p_traj) 84 | 85 | return noisy_batch_x 86 | 87 | 88 | def generate_outliers(trajs, ratio=args.ratio, level=args.distance, point_prob=args.fraction): 89 | traj_num = len(trajs) 90 | selected_idx = np.random.randint(0, traj_num, size=int(traj_num * ratio)) 91 | new_trajs = perturb_batch(trajs, level, point_prob, selected_idx) 92 | return new_trajs, selected_idx 93 | 94 | 95 | if __name__ == '__main__': 96 | np.random.seed(1234) 97 | print("=========================") 98 | print("Dataset: " + args.dataset) 99 | print("d = {}".format(args.distance) + ", " + chr(945) + " = {}".format(args.fraction) + ", " 100 | + chr(961) + " = {}".format(args.obeserved_ratio)) 101 | 102 | if args.dataset == 'porto': 103 | map_size = (51, 119) 104 | elif args.dataset == 'cd': 105 | map_size = (167, 154) 106 | 107 | data = 
np.load("./data/{}/test_data_init.npy".format(args.dataset), allow_pickle=True) 108 | outliers_trajs, outliers_idx = generate_outliers(data) 109 | outliers_trajs = np.array(outliers_trajs, dtype=object) 110 | outliers_idx = np.array(outliers_idx) 111 | 112 | np.save("./data/{}/outliers_data_init_{}_{}_{}.npy".format(args.dataset, args.distance, args.fraction, 113 | args.obeserved_ratio), outliers_trajs) 114 | np.save("./data/{}/outliers_idx_init_{}_{}_{}.npy".format(args.dataset, args.distance, args.fraction, 115 | args.obeserved_ratio), outliers_idx) 116 | 117 | if args.dataset == 'cd': 118 | 119 | traj_path = "../datasets/chengdu" 120 | path_list = os.listdir(traj_path) 121 | path_list.sort(key=lambda x: x.split('.')) 122 | 123 | for file in path_list[3: 10]: 124 | if file[-4:] == '.txt': 125 | data = np.load("./data/{}/test_data_{}.npy".format(args.dataset, file[:8]), 126 | allow_pickle=True) 127 | outliers_trajs, outliers_idx = generate_outliers(data) 128 | outliers_trajs = np.array(outliers_trajs, dtype=object) 129 | outliers_idx = np.array(outliers_idx) 130 | 131 | np.save("./data/{}/outliers_data_{}_{}_{}_{}.npy".format(args.dataset, file[:8], args.distance, 132 | args.fraction, args.obeserved_ratio), outliers_trajs) 133 | np.save("./data/{}/outliers_idx_{}_{}_{}_{}.npy".format(args.dataset, file[:8], args.distance, 134 | args.fraction, args.obeserved_ratio), outliers_idx) 135 | 136 | if args.dataset == 'porto': 137 | for i in range(1, 11): 138 | data = np.load("./data/{}/test_data_{}.npy".format(args.dataset, i), allow_pickle=True) 139 | outliers_trajs, outliers_idx = generate_outliers(data) 140 | outliers_trajs = np.array(outliers_trajs, dtype=object) 141 | outliers_idx = np.array(outliers_idx) 142 | 143 | np.save("./data/{}/outliers_data_{}_{}_{}_{}.npy".format(args.dataset, i, args.distance, 144 | args.fraction, args.obeserved_ratio), outliers_trajs) 145 | np.save("./data/{}/outliers_idx_{}_{}_{}_{}.npy".format(args.dataset, i, args.distance, 146 | args.fraction, args.obeserved_ratio), outliers_idx) 147 | -------------------------------------------------------------------------------- /mst_oatd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | 8 | from temporal import TemporalEmbedding 9 | 10 | 11 | class co_attention(nn.Module): 12 | def __init__(self, dim): 13 | super(co_attention, self).__init__() 14 | 15 | self.Wq_s = nn.Linear(dim, dim, bias=False) 16 | self.Wk_s = nn.Linear(dim, dim, bias=False) 17 | self.Wv_s = nn.Linear(dim, dim, bias=False) 18 | 19 | self.Wq_t = nn.Linear(dim, dim, bias=False) 20 | self.Wk_t = nn.Linear(dim, dim, bias=False) 21 | self.Wv_t = nn.Linear(dim, dim, bias=False) 22 | 23 | self.dim_k = dim ** 0.5 24 | 25 | self.FFN_s = nn.Sequential( 26 | nn.Linear(dim, dim), 27 | nn.ReLU(), 28 | nn.Linear(dim, dim), 29 | nn.Dropout(0.1) 30 | ) 31 | 32 | self.FFN_t = nn.Sequential( 33 | nn.Linear(dim, dim), 34 | nn.ReLU(), 35 | nn.Linear(dim, dim), 36 | nn.Dropout(0.1) 37 | ) 38 | 39 | self.layer_norm = nn.LayerNorm(dim, eps=1e-6) 40 | 41 | def forward(self, seq_s, seq_t): 42 | seq_t = seq_t.unsqueeze(2) 43 | seq_s = seq_s.unsqueeze(2) 44 | 45 | q_s, k_s, v_s = self.Wq_s(seq_t), self.Wk_s(seq_s), self.Wv_s(seq_s) 46 | q_t, k_t, v_t = self.Wq_t(seq_s), self.Wk_t(seq_t), self.Wv_t(seq_t) 47 | 48 | coatt_s = 
F.softmax(torch.matmul(q_s / self.dim_k, k_s.transpose(2, 3)), dim=-1) 49 | coatt_t = F.softmax(torch.matmul(q_t / self.dim_k, k_t.transpose(2, 3)), dim=-1) 50 | 51 | att_s = self.layer_norm(self.FFN_s(torch.matmul(coatt_s, v_s)) + torch.matmul(coatt_s, v_s)) 52 | att_t = self.layer_norm(self.FFN_t(torch.matmul(coatt_t, v_t)) + torch.matmul(coatt_t, v_t)) 53 | 54 | return att_s.squeeze(2), att_t.squeeze(2) 55 | 56 | 57 | class state_attention(nn.Module): 58 | def __init__(self, args): 59 | super(state_attention, self).__init__() 60 | 61 | self.w_omega = nn.Parameter(torch.Tensor(args.hidden_size, args.hidden_size)) 62 | self.u_omega = nn.Parameter(torch.Tensor(args.hidden_size, 1)) 63 | nn.init.uniform_(self.w_omega, -0.1, 0.1) 64 | nn.init.uniform_(self.u_omega, -0.1, 0.1) 65 | 66 | def forward(self, seq): 67 | u = torch.tanh(torch.matmul(seq, self.w_omega)) 68 | att = torch.matmul(u, self.u_omega).squeeze() 69 | att_score = F.softmax(att, dim=1).unsqueeze(2) 70 | scored_outputs = seq * att_score 71 | return scored_outputs.sum(1) 72 | 73 | 74 | class MST_OATD(nn.Module): 75 | def __init__(self, token_size, token_size_out, args): 76 | super(MST_OATD, self).__init__() 77 | 78 | self.emb_size = args.embedding_size 79 | self.device = args.device 80 | self.n_cluster = args.n_cluster 81 | self.dataset = args.dataset 82 | self.s1_size = args.s1_size 83 | self.s2_size = args.s2_size 84 | 85 | self.pi_prior = nn.Parameter(torch.ones(args.n_cluster) / args.n_cluster) 86 | self.mu_prior = nn.Parameter(torch.randn(args.n_cluster, args.hidden_size)) 87 | self.log_var_prior = nn.Parameter(torch.zeros(args.n_cluster, args.hidden_size)) 88 | 89 | self.embedding = nn.Embedding(token_size, args.embedding_size) 90 | 91 | self.encoder_s1 = nn.GRU(args.embedding_size * 2, args.hidden_size, 1, batch_first=True) 92 | self.encoder_s2 = nn.GRU(args.embedding_size * 2, args.hidden_size, 1, batch_first=True) 93 | self.encoder_s3 = nn.GRU(args.embedding_size * 2, args.hidden_size, 1, batch_first=True) 94 | 95 | self.decoder = nn.GRU(args.embedding_size * 2, args.hidden_size, 1, batch_first=True) 96 | 97 | self.fc_mu = nn.Linear(args.hidden_size, args.hidden_size) 98 | self.fc_logvar = nn.Linear(args.hidden_size, args.hidden_size) 99 | self.layer_norm = nn.LayerNorm(args.hidden_size) 100 | 101 | self.fc_out = nn.Linear(args.hidden_size, token_size_out) 102 | 103 | self.nodes = torch.arange(token_size, dtype=torch.long).to(args.device) 104 | self.adj = sparse.load_npz("data/{}/adj.npz".format(args.dataset)) 105 | self.d_norm = sparse.load_npz("data/{}/d_norm.npz".format(args.dataset)) 106 | 107 | if args.dataset == 'porto': 108 | self.V = nn.Parameter(torch.Tensor(token_size, token_size)) 109 | else: 110 | self.V = nn.Parameter(torch.Tensor(args.embedding_size, args.embedding_size)) 111 | 112 | self.W1 = nn.Parameter(torch.ones(1) / 3) 113 | self.W2 = nn.Parameter(torch.ones(1) / 3) 114 | self.W3 = nn.Parameter(torch.ones(1) / 3) 115 | 116 | self.co_attention = co_attention(args.embedding_size).to(args.device) 117 | self.d2v = TemporalEmbedding(args.device) 118 | 119 | self.w_omega = nn.Parameter(torch.Tensor(args.embedding_size * 2, args.embedding_size * 2)) 120 | self.u_omega = nn.Parameter(torch.Tensor(args.embedding_size * 2, 1)) 121 | 122 | nn.init.uniform_(self.V, -0.2, 0.2) 123 | nn.init.uniform_(self.w_omega, -0.1, 0.1) 124 | nn.init.uniform_(self.u_omega, -0.1, 0.1) 125 | 126 | self.state_att = state_attention(args) 127 | self.dataset = args.dataset 128 | 129 | def scale_process(self, e_inputs, scale_size, 
lengths): 130 | e_inputs_split = torch.mean(e_inputs.unfold(1, scale_size, scale_size), dim=3) 131 | e_inputs_split = self.attention_layer(e_inputs_split, lengths) 132 | e_inputs_split = pack_padded_sequence(e_inputs_split, lengths, batch_first=True, enforce_sorted=False) 133 | return e_inputs_split 134 | 135 | def Norm_A(self, A, D): 136 | return D.mm(A).mm(self.V).mm(D) 137 | 138 | def Norm_A_N(self, A, D): 139 | return D.mm(A).mm(D) 140 | 141 | def reparameterize(self, mu, log_var): 142 | std = torch.exp(log_var * 0.5) 143 | eps = torch.randn_like(std) 144 | return mu + eps * std 145 | 146 | def padding_mask(self, inp): 147 | return inp == 0 148 | 149 | def attention_layer(self, e_input, lengths): 150 | mask = self.getMask(lengths) 151 | u = torch.tanh(torch.matmul(e_input, self.w_omega)) 152 | att = torch.matmul(u, self.u_omega).squeeze() 153 | att = att.masked_fill(mask == 0, -1e10) 154 | att_score = F.softmax(att, dim=1).unsqueeze(2) 155 | att_e_input = e_input * att_score 156 | return att_e_input 157 | 158 | def array2sparse(self, A): 159 | A = A.tocoo() 160 | values = A.data 161 | indices = np.vstack((A.row, A.col)) 162 | i = torch.LongTensor(indices).to(self.device) 163 | v = torch.FloatTensor(values).to(self.device) 164 | A = torch.sparse_coo_tensor(i, v, torch.Size(A.shape), dtype=torch.float32) 165 | return A 166 | 167 | def getMask(self, seq_lengths): 168 | max_len = max(seq_lengths) 169 | mask = torch.ones((len(seq_lengths), max_len)).to(self.device) 170 | 171 | for i, l in enumerate(seq_lengths): 172 | if l < max_len: 173 | mask[i, l:] = 0 174 | return mask 175 | 176 | def forward(self, trajs, times, lengths, batch_size, mode, c): 177 | 178 | # spatial embedding 179 | adj = self.array2sparse(self.adj) 180 | d_norm = self.array2sparse(self.d_norm) 181 | if self.dataset == 'porto': 182 | H = self.Norm_A(adj, d_norm) 183 | nodes = H.mm(self.embedding(self.nodes)) 184 | else: 185 | H = self.Norm_A_N(adj, d_norm) 186 | nodes = H.mm(self.embedding(self.nodes)).mm(self.V) 187 | 188 | s_inputs = torch.index_select(nodes, 0, trajs.flatten().to(self.device)). 
\ 189 | reshape(batch_size, -1, self.emb_size) 190 | 191 | # temporal embedding 192 | t_inputs = self.d2v(times.to(self.device)).to(self.device) 193 | 194 | att_s, att_t = self.co_attention(s_inputs, t_inputs) 195 | st_inputs = torch.concat((att_s, att_t), dim=2) 196 | d_inputs = torch.cat((torch.zeros(batch_size, 1, self.emb_size * 2, dtype=torch.long).to(self.device), 197 | st_inputs[:, :-1, :]), dim=1) # [bs, seq_len, emb_size * 2] 198 | 199 | decoder_inputs = pack_padded_sequence(d_inputs, lengths, batch_first=True, enforce_sorted=False) 200 | 201 | if mode == "pretrain" or "train": 202 | encoder_inputs_s1 = pack_padded_sequence(self.attention_layer(st_inputs, lengths), lengths, 203 | batch_first=True, enforce_sorted=False) 204 | encoder_inputs_s2 = self.scale_process(st_inputs, self.s1_size, [int(i // self.s1_size) for i in lengths]) 205 | encoder_inputs_s3 = self.scale_process(st_inputs, self.s2_size, [int(i // self.s2_size) for i in lengths]) 206 | 207 | _, encoder_final_state_s1 = self.encoder_s1(encoder_inputs_s1) 208 | _, encoder_final_state_s2 = self.encoder_s2(encoder_inputs_s2) 209 | _, encoder_final_state_s3 = self.encoder_s3(encoder_inputs_s3) 210 | 211 | encoder_final_state = (self.W1 * encoder_final_state_s1 + self.W2 * encoder_final_state_s2 212 | + self.W3 * encoder_final_state_s3) 213 | sum_W = self.W1.data + self.W2.data + self.W3.data 214 | self.W1.data /= sum_W 215 | self.W2.data /= sum_W 216 | self.W3.data /= sum_W 217 | 218 | mu = self.fc_mu(encoder_final_state) 219 | logvar = self.fc_logvar(encoder_final_state) 220 | z = self.reparameterize(mu, logvar) 221 | 222 | decoder_outputs, _ = self.decoder(decoder_inputs, z) 223 | decoder_outputs, _ = pad_packed_sequence(decoder_outputs, batch_first=True) 224 | 225 | elif mode == "test": 226 | mu = torch.stack([self.mu_prior] * batch_size, dim=1)[c: c + 1] 227 | decoder_outputs, _ = self.decoder(decoder_inputs, mu) 228 | decoder_outputs, _ = pad_packed_sequence(decoder_outputs, batch_first=True) 229 | logvar, z = None, None 230 | 231 | output = self.fc_out(self.layer_norm(decoder_outputs)) 232 | 233 | return output, mu, logvar, z 234 | -------------------------------------------------------------------------------- /train_update.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from sklearn.mixture import GaussianMixture 7 | from torch.utils.data import DataLoader 8 | 9 | from config import args 10 | from mst_oatd_trainer import train_mst_oatd, MyDataset, seed_torch, collate_fn 11 | from train_labels import Linear_Model 12 | 13 | 14 | def get_z(trajs): 15 | data = MyDataset(trajs) 16 | loader = DataLoader(dataset=data, batch_size=args.batch_size, shuffle=False, 17 | collate_fn=collate_fn, num_workers=4) 18 | MST_OATD_U.train_loader = loader 19 | return MST_OATD_U.get_hidden().cpu() 20 | 21 | 22 | def load_gmm(): 23 | checkpoint = torch.load('./models/gmm_update_{}.pt'.format(args.dataset)) 24 | gmm = GaussianMixture(n_components=args.n_cluster, covariance_type='diag') 25 | gmm.weights_ = checkpoint['gmm_update_weights'] 26 | gmm.means_ = checkpoint['gmm_update_means'] 27 | gmm.covariances_ = checkpoint['gmm_update_covariances'] 28 | gmm.precisions_cholesky_ = checkpoint['gmm_update_precisions_cholesky'] 29 | return gmm 30 | 31 | 32 | def get_index(trajs, cats_sample): 33 | z = get_z(trajs) 34 | cats = gmm.predict(z) 35 | index = [cat in cats_sample for cat in cats] 36 | trajs = trajs[index] 37 | return index, 
z[index], trajs 38 | 39 | 40 | def get_score(z): 41 | probs = gmm.predict_proba(z) 42 | 43 | idx = [] 44 | linear = Linear_Model() 45 | for label in range(args.n_cluster): 46 | data = -probs[:, label] 47 | rank = linear.test(label, torch.Tensor(data).to(args.device)) 48 | idx.append(rank) 49 | idx = np.array(idx).T 50 | idxs = np.argsort(idx, axis=1) 51 | 52 | return idxs 53 | 54 | 55 | def update_data(origin_trajs, train_trajs, cats_sample): 56 | _, z, train_trajs = get_index(train_trajs, cats_sample) 57 | idxs = get_score(z) 58 | 59 | max_idxs = idxs[:, 0] 60 | for i, traj in enumerate(train_trajs): 61 | max_idx = max_idxs[i] 62 | origin_trajs[max_idx].append(traj) 63 | 64 | min_c = args.n_cluster - 1 65 | min_idx = idxs[:, min_c][i] 66 | 67 | while not origin_trajs[min_idx]: 68 | min_c -= 1 69 | min_idx = idxs[:, min_c][i] 70 | origin_trajs[min_idx].pop(0) 71 | 72 | return np.array(sum(origin_trajs, []), dtype=object) 73 | 74 | 75 | def train_update(train_trajs, test_trajs, labels, i): 76 | train_data = MyDataset(train_trajs) 77 | test_data = MyDataset(test_trajs) 78 | 79 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, 80 | collate_fn=collate_fn, num_workers=8) 81 | outliers_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False, 82 | collate_fn=collate_fn, num_workers=8) 83 | 84 | MST_OATD.train_loader = train_loader 85 | MST_OATD.outliers_loader = outliers_loader 86 | MST_OATD.labels = labels 87 | 88 | pr_auc = [] 89 | for epoch in range(args.epochs): 90 | MST_OATD.train(epoch) 91 | results = MST_OATD.detection() 92 | pr_auc.append(results) 93 | results = "%.4f" % max(pr_auc) 94 | print("File {} PR_AUC:".format(i), results) 95 | return max(pr_auc) 96 | 97 | 98 | def test_update(test_trajs, labels, i): 99 | test_data = MyDataset(test_trajs) 100 | outliers_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False, 101 | collate_fn=collate_fn, num_workers=8) 102 | 103 | MST_OATD_U.outliers_loader = outliers_loader 104 | MST_OATD_U.labels = labels 105 | 106 | pr_auc = MST_OATD_U.detection() 107 | results = "%.4f" % pr_auc 108 | print("File {} PR_AUC:".format(i), results) 109 | return pr_auc 110 | 111 | 112 | # Get the type of trajectory and save the trajectory data in the n_cluster list. 
113 | def get_category(trajs): 114 | z = get_z(trajs) 115 | c_labels = gmm.predict(z) 116 | 117 | origin_trajs = [] 118 | for label in range(args.n_cluster): 119 | index = c_labels == label 120 | origin_trajs.append(trajs[index].tolist()) 121 | return origin_trajs 122 | 123 | 124 | def main(): 125 | random.seed(1234) 126 | train_trajs = np.load('./data/{}/train_data_init.npy'.format(args.dataset), 127 | allow_pickle=True)[-args.train_num:] 128 | 129 | if args.update_mode == 'rank': 130 | random_train_trajs = train_trajs 131 | 132 | all_pr_auc = [] 133 | 134 | if args.dataset == 'porto': 135 | for i in range(1, 11): 136 | train_trajs_new = np.load('./data/{}/train_data_{}.npy'.format(args.dataset, i), 137 | allow_pickle=True) 138 | test_trajs = np.load( 139 | './data/{}/outliers_data_{}_{}_{}_{}.npy'.format(args.dataset, i, args.distance, args.fraction, 140 | args.obeserved_ratio), 141 | allow_pickle=True) 142 | outliers_idx = np.load( 143 | "./data/{}/outliers_idx_{}_{}_{}_{}.npy".format(args.dataset, i, args.distance, args.fraction, 144 | args.obeserved_ratio), 145 | allow_pickle=True) 146 | 147 | labels = np.zeros(len(test_trajs)) 148 | for idx in outliers_idx: 149 | labels[idx] = 1 150 | 151 | cats = list(range(0, args.n_cluster)) 152 | cats_sample = random.sample(cats, args.n_cluster // 4) 153 | test_index, _, _ = get_index(test_trajs, cats_sample) 154 | 155 | if args.update_mode == 'temporal': 156 | train_index, _, _ = get_index(train_trajs_new, cats_sample) 157 | train_trajs = np.concatenate((train_trajs, train_trajs_new[train_index]))[-len(train_trajs):] 158 | pr_auc = train_update(train_trajs, test_trajs[test_index], labels[test_index], i) 159 | 160 | elif args.update_mode == 'rank': 161 | trajs = get_category(random_train_trajs) 162 | train_trajs = update_data(trajs, train_trajs_new, cats_sample) 163 | random_train_trajs = np.concatenate((random_train_trajs, train_trajs_new))[-len(train_trajs):] 164 | pr_auc = train_update(train_trajs, test_trajs[test_index], labels[test_index], i) 165 | 166 | elif args.update_mode == 'pretrain': 167 | pr_auc = test_update(test_trajs[test_index], labels[test_index], i) 168 | 169 | all_pr_auc.append(pr_auc) 170 | 171 | if args.dataset == 'cd': 172 | traj_path = "../datasets/chengdu" 173 | path_list = os.listdir(traj_path) 174 | path_list.sort(key=lambda x: x.split('.')) 175 | path_list = path_list[3: 10] 176 | 177 | for i in range(len(path_list)): 178 | train_trajs_new = np.load('./data/{}/train_data_{}.npy'.format(args.dataset, path_list[i][:8]), 179 | allow_pickle=True) 180 | print(len(train_trajs_new)) 181 | test_trajs = np.load( 182 | './data/{}/outliers_data_{}_{}_{}_{}.npy'.format(args.dataset, path_list[i][:8], args.distance, 183 | args.fraction, args.obeserved_ratio), 184 | allow_pickle=True) 185 | outliers_idx = np.load( 186 | "./data/{}/outliers_idx_{}_{}_{}_{}.npy".format(args.dataset, path_list[i][:8], args.distance, 187 | args.fraction, args.obeserved_ratio), 188 | allow_pickle=True) 189 | 190 | labels = np.zeros(len(test_trajs)) 191 | for idx in outliers_idx: 192 | labels[idx] = 1 193 | 194 | cats = list(range(0, args.n_cluster)) 195 | cats_sample = random.sample(cats, args.n_cluster // 4) 196 | test_index, _, _ = get_index(test_trajs, cats_sample) 197 | 198 | test_trajs = test_trajs[test_index] 199 | labels = labels[test_index] 200 | 201 | if args.update_mode == 'temporal': 202 | 203 | train_index, _, _ = get_index(train_trajs_new, cats_sample) 204 | train_trajs_new = train_trajs_new[train_index] 205 | 206 | train_trajs = 
train_trajs[-(len(train_trajs) - len(train_trajs_new)):] 207 | train_trajs = np.concatenate((train_trajs, train_trajs_new)) 208 | print('Trajecotory num:', len(train_trajs)) 209 | pr_auc = train_update(train_trajs, test_trajs, labels, i) 210 | 211 | elif args.update_mode == 'rank': 212 | 213 | trajs = get_category(random_train_trajs) 214 | train_trajs = update_data(trajs, train_trajs_new, cats_sample) 215 | 216 | random_train_trajs = random_train_trajs[-(len(random_train_trajs) - len(train_trajs_new)):] 217 | random_train_trajs = np.concatenate((random_train_trajs, train_trajs_new)) 218 | print('Trajecotory num:', len(train_trajs)) 219 | pr_auc = train_update(train_trajs, test_trajs, labels, i) 220 | 221 | elif args.update_mode == 'pretrain': 222 | pr_auc = test_update(test_trajs, labels, i) 223 | 224 | all_pr_auc.append(pr_auc) 225 | print('------------------------') 226 | results = "%.4f" % (sum(all_pr_auc) / len(all_pr_auc)) 227 | print('Average PR_AUC:', results) 228 | print('------------------------') 229 | 230 | 231 | if __name__ == "__main__": 232 | seed_torch(1234) 233 | print("===========================") 234 | print("Dataset:", args.dataset) 235 | print("Mode:", args.update_mode) 236 | 237 | if args.dataset == 'porto': 238 | s_token_size = 51 * 119 239 | t_token_size = 5760 240 | elif args.dataset == 'cd': 241 | s_token_size = 167 * 154 242 | t_token_size = 8640 243 | 244 | gmm = load_gmm() 245 | 246 | MST_OATD = train_mst_oatd(s_token_size, t_token_size, None, None, None, args) 247 | 248 | MST_OATD.mode = 'update' 249 | checkpoint = torch.load(MST_OATD.path_checkpoint) 250 | MST_OATD.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s']) 251 | MST_OATD.MST_OATD_T.load_state_dict(checkpoint['model_state_dict_t']) 252 | 253 | MST_OATD_U = train_mst_oatd(s_token_size, t_token_size, None, None, None, args) 254 | 255 | checkpoint_U = torch.load(MST_OATD_U.path_checkpoint) 256 | MST_OATD_U.mode = 'update' 257 | MST_OATD_U.MST_OATD_S.load_state_dict(checkpoint_U['model_state_dict_s']) 258 | MST_OATD_U.MST_OATD_T.load_state_dict(checkpoint_U['model_state_dict_t']) 259 | main() 260 | -------------------------------------------------------------------------------- /mst_oatd_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from sklearn.mixture import GaussianMixture 10 | from torch.optim.lr_scheduler import StepLR 11 | from torch.utils.data import Dataset 12 | 13 | from logging_set import get_logger 14 | from mst_oatd import MST_OATD 15 | from utils import auc_score, make_mask, make_len_mask 16 | 17 | 18 | def collate_fn(batch): 19 | max_len = max(len(x) for x in batch) 20 | seq_lengths = list(map(len, batch)) 21 | batch_trajs = [x + [[0, [0] * 6]] * (max_len - len(x)) for x in batch] 22 | return torch.LongTensor(np.array(batch_trajs, dtype=object)[:, :, 0].tolist()), \ 23 | torch.Tensor(np.array(batch_trajs, dtype=object)[:, :, 1].tolist()), np.array(seq_lengths) 24 | 25 | 26 | def seed_torch(seed): 27 | random.seed(seed) 28 | os.environ['PYTHONHASHSEED'] = str(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | torch.backends.cudnn.benchmark = False 33 | torch.backends.cudnn.deterministic = True 34 | 35 | 36 | class MyDataset(Dataset): 37 | def __init__(self, seqs): 38 | self.seqs = seqs 39 | 40 | def __len__(self): 41 
| return len(self.seqs) 42 | 43 | def __getitem__(self, index): 44 | data_seqs = self.seqs[index] 45 | return data_seqs 46 | 47 | 48 | def time_convert(times, time_interval): 49 | return torch.Tensor((times[:, :, 2] + times[:, :, 1] * 60 + times[:, :, 0] * 3600) // time_interval).long() 50 | 51 | 52 | def savecheckpoint(state, file_name): 53 | torch.save(state, file_name) 54 | 55 | 56 | class train_mst_oatd: 57 | def __init__(self, s_token_size, t_token_size, labels, train_loader, outliers_loader, args): 58 | 59 | self.MST_OATD_S = MST_OATD(s_token_size, s_token_size, args).to(args.device) 60 | self.MST_OATD_T = MST_OATD(s_token_size, t_token_size, args).to(args.device) 61 | 62 | self.device = args.device 63 | self.dataset = args.dataset 64 | self.n_cluster = args.n_cluster 65 | self.hidden_size = args.hidden_size 66 | 67 | self.crit = nn.CrossEntropyLoss() 68 | self.detec = nn.CrossEntropyLoss(reduction='none') 69 | 70 | self.pretrain_optimizer_s = optim.AdamW([ 71 | {'params': self.MST_OATD_S.parameters()}, 72 | ], lr=args.pretrain_lr_s) 73 | 74 | self.pretrain_optimizer_t = optim.AdamW([ 75 | {'params': self.MST_OATD_T.parameters()}, 76 | ], lr=args.pretrain_lr_t) 77 | 78 | self.optimizer_s = optim.AdamW([ 79 | {'params': self.MST_OATD_S.parameters()}, 80 | ], lr=args.lr_s) 81 | 82 | self.optimizer_t = optim.Adam([ 83 | {'params': self.MST_OATD_T.parameters()}, 84 | ], lr=args.lr_t) 85 | 86 | self.lr_pretrain_s = StepLR(self.pretrain_optimizer_s, step_size=2, gamma=0.9) 87 | self.lr_pretrain_t = StepLR(self.pretrain_optimizer_t, step_size=2, gamma=0.9) 88 | 89 | self.train_loader = train_loader 90 | self.outliers_loader = outliers_loader 91 | 92 | self.pretrained_path = 'models/pretrain_mstoatd_{}.pth'.format(args.dataset) 93 | self.path_checkpoint = 'models/mstoatd_{}.pth'.format(args.dataset) 94 | self.gmm_path = "models/gmm_{}.pt".format(args.dataset) 95 | self.gmm_update_path = "models/gmm_update_{}.pt".format(args.dataset) 96 | self.logger = get_logger("./logs/{}.log".format(args.dataset)) 97 | 98 | self.labels = labels 99 | if args.dataset == 'cd': 100 | self.time_interval = 10 101 | else: 102 | self.time_interval = 15 103 | self.mode = 'train' 104 | 105 | self.s1_size = args.s1_size 106 | self.s2_size = args.s2_size 107 | 108 | def pretrain(self, epoch): 109 | self.MST_OATD_S.train() 110 | self.MST_OATD_T.train() 111 | epo_loss = 0 112 | 113 | for batch in self.train_loader: 114 | trajs, times, seq_lengths = batch 115 | batch_size = len(trajs) 116 | 117 | mask = make_mask(make_len_mask(trajs)).to(self.device) 118 | 119 | self.pretrain_optimizer_s.zero_grad() 120 | self.pretrain_optimizer_t.zero_grad() 121 | output_s, _, _, _ = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "pretrain", -1) 122 | output_t, _, _, _ = self.MST_OATD_T(trajs, times, seq_lengths, batch_size, "pretrain", -1) 123 | 124 | times = time_convert(times, self.time_interval) 125 | 126 | loss = self.crit(output_s[mask == 1], trajs.to(self.device)[mask == 1]) 127 | loss += self.crit(output_t[mask == 1], times.to(self.device)[mask == 1]) 128 | 129 | loss.backward() 130 | 131 | self.pretrain_optimizer_s.step() 132 | self.pretrain_optimizer_t.step() 133 | epo_loss += loss.item() 134 | 135 | self.lr_pretrain_s.step() 136 | self.lr_pretrain_t.step() 137 | epo_loss = "%.4f" % (epo_loss / len(self.train_loader)) 138 | self.logger.info("Epoch {} pretrain loss: {}".format(epoch + 1, epo_loss)) 139 | checkpoint = {"model_state_dict_s": self.MST_OATD_S.state_dict(), 140 | "model_state_dict_t": 
self.MST_OATD_T.state_dict()} 141 | torch.save(checkpoint, self.pretrained_path) 142 | 143 | def get_hidden(self): 144 | checkpoint = torch.load(self.path_checkpoint) 145 | self.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s']) 146 | self.MST_OATD_S.eval() 147 | with torch.no_grad(): 148 | z = [] 149 | for batch in self.train_loader: 150 | trajs, times, seq_lengths = batch 151 | batch_size = len(trajs) 152 | _, _, _, hidden = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "pretrain", -1) 153 | z.append(hidden.squeeze(0)) 154 | z = torch.cat(z, dim=0) 155 | return z 156 | 157 | def train_gmm(self): 158 | self.MST_OATD_S.eval() 159 | self.MST_OATD_T.eval() 160 | checkpoint = torch.load(self.pretrained_path) 161 | self.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s']) 162 | self.MST_OATD_T.load_state_dict(checkpoint['model_state_dict_t']) 163 | 164 | with torch.no_grad(): 165 | z_s = [] 166 | z_t = [] 167 | for batch in self.train_loader: 168 | trajs, times, seq_lengths = batch 169 | batch_size = len(trajs) 170 | _, _, _, hidden_s = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "pretrain", -1) 171 | _, _, _, hidden_t = self.MST_OATD_T(trajs, times, seq_lengths, batch_size, "pretrain", -1) 172 | 173 | z_s.append(hidden_s.squeeze(0)) 174 | z_t.append(hidden_t.squeeze(0)) 175 | z_s = torch.cat(z_s, dim=0) 176 | z_t = torch.cat(z_t, dim=0) 177 | 178 | self.logger.info('Start fitting gaussian mixture model!') 179 | 180 | self.gmm_s = GaussianMixture(n_components=self.n_cluster, covariance_type="diag", n_init=1) 181 | self.gmm_s.fit(z_s.cpu().numpy()) 182 | 183 | self.gmm_t = GaussianMixture(n_components=self.n_cluster, covariance_type="diag", n_init=1) 184 | self.gmm_t.fit(z_t.cpu().numpy()) 185 | 186 | def save_weights_for_MSTOATD(self): 187 | savecheckpoint({"gmm_s_mu_prior": self.gmm_s.means_, 188 | "gmm_s_pi_prior": self.gmm_s.weights_, 189 | "gmm_s_logvar_prior": self.gmm_s.covariances_, 190 | "gmm_t_mu_prior": self.gmm_t.means_, 191 | "gmm_t_pi_prior": self.gmm_t.weights_, 192 | "gmms_t_logvar_prior": self.gmm_t.covariances_}, self.gmm_path) 193 | 194 | def train_gmm_update(self): 195 | 196 | checkpoint = torch.load(self.path_checkpoint) 197 | self.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s']) 198 | self.MST_OATD_S.eval() 199 | 200 | with torch.no_grad(): 201 | z = [] 202 | for batch in self.train_loader: 203 | trajs, times, seq_lengths = batch 204 | batch_size = len(trajs) 205 | _, _, _, hidden = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "pretrain", -1) 206 | z.append(hidden.squeeze(0)) 207 | z = torch.cat(z, dim=0) 208 | 209 | self.logger.info('Start fitting gaussian mixture model!') 210 | 211 | self.gmm = GaussianMixture(n_components=self.n_cluster, covariance_type="diag", n_init=3) 212 | self.gmm.fit(z.cpu().numpy()) 213 | 214 | savecheckpoint({"gmm_update_weights": self.gmm.weights_, 215 | "gmm_update_means": self.gmm.means_, 216 | "gmm_update_covariances": self.gmm.covariances_, 217 | "gmm_update_precisions_cholesky": self.gmm.precisions_cholesky_}, self.gmm_update_path) 218 | 219 | def train(self, epoch): 220 | self.MST_OATD_S.train() 221 | self.MST_OATD_T.train() 222 | total_loss = 0 223 | for batch in self.train_loader: 224 | trajs, times, seq_lengths = batch 225 | batch_size = len(trajs) 226 | 227 | mask = make_mask(make_len_mask(trajs)).to(self.device) 228 | 229 | self.optimizer_s.zero_grad() 230 | self.optimizer_t.zero_grad() 231 | 232 | x_hat_s, mu_s, log_var_s, z_s = self.MST_OATD_S(trajs, times, seq_lengths, 
batch_size, "train", -1) 233 | loss = self.Loss(x_hat_s, trajs.to(self.device), mu_s.squeeze(0), log_var_s.squeeze(0), 234 | z_s.squeeze(0), 's', mask) 235 | x_hat_t, mu_t, log_var_t, z_t = self.MST_OATD_T(trajs, times, seq_lengths, batch_size, "train", -1) 236 | times = time_convert(times, self.time_interval) 237 | loss += self.Loss(x_hat_t, times.to(self.device), mu_t.squeeze(0), log_var_t.squeeze(0), 238 | z_t.squeeze(0), 't', mask) 239 | 240 | loss.backward() 241 | self.optimizer_s.step() 242 | self.optimizer_t.step() 243 | total_loss += loss.item() 244 | 245 | if self.mode == "train": 246 | total_loss = "%.4f" % (total_loss / len(self.train_loader)) 247 | self.logger.info('Epoch {} loss: {}'.format(epoch + 1, total_loss)) 248 | checkpoint = {"model_state_dict_s": self.MST_OATD_S.state_dict(), 249 | "model_state_dict_t": self.MST_OATD_T.state_dict()} 250 | torch.save(checkpoint, self.path_checkpoint) 251 | 252 | def detection(self): 253 | 254 | self.MST_OATD_S.eval() 255 | all_likelihood_s = [] 256 | self.MST_OATD_T.eval() 257 | all_likelihood_t = [] 258 | 259 | with torch.no_grad(): 260 | 261 | for batch in self.outliers_loader: 262 | trajs, times, seq_lengths = batch 263 | batch_size = len(trajs) 264 | mask = make_mask(make_len_mask(trajs)).to(self.device) 265 | times_token = time_convert(times, self.time_interval) 266 | 267 | c_likelihood_s = [] 268 | c_likelihood_t = [] 269 | 270 | for c in range(self.n_cluster): 271 | output_s, _, _, _ = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "test", c) 272 | likelihood_s = - self.detec(output_s.reshape(-1, output_s.shape[-1]), 273 | trajs.to(self.device).reshape(-1)) 274 | likelihood_s = torch.exp( 275 | torch.sum(mask * (likelihood_s.reshape(batch_size, -1)), dim=-1) / torch.sum(mask, 1)) 276 | 277 | output_t, _, _, _ = self.MST_OATD_T(trajs, times, seq_lengths, batch_size, "test", c) 278 | likelihood_t = - self.detec(output_t.reshape(-1, output_t.shape[-1]), 279 | times_token.to(self.device).reshape(-1)) 280 | likelihood_t = torch.exp( 281 | torch.sum(mask * (likelihood_t.reshape(batch_size, -1)), dim=-1) / torch.sum(mask, 1)) 282 | 283 | c_likelihood_s.append(likelihood_s.unsqueeze(0)) 284 | c_likelihood_t.append(likelihood_t.unsqueeze(0)) 285 | 286 | all_likelihood_s.append(torch.cat(c_likelihood_s).max(0)[0]) 287 | all_likelihood_t.append(torch.cat(c_likelihood_t).max(0)[0]) 288 | 289 | likelihood_s = torch.cat(all_likelihood_s, dim=0) 290 | likelihood_t = torch.cat(all_likelihood_t, dim=0) 291 | 292 | pr_auc = auc_score(self.labels, (1 - likelihood_s * likelihood_t).cpu().detach().numpy()) 293 | return pr_auc 294 | 295 | def gaussian_pdf_log(self, x, mu, log_var): 296 | return -0.5 * (torch.sum(np.log(np.pi * 2) + log_var + (x - mu).pow(2) / torch.exp(log_var), 1)) 297 | 298 | def gaussian_pdfs_log(self, x, mus, log_vars): 299 | G = [] 300 | for c in range(self.n_cluster): 301 | G.append(self.gaussian_pdf_log(x, mus[c:c + 1, :], log_vars[c:c + 1, :]).view(-1, 1)) 302 | return torch.cat(G, 1) 303 | 304 | def Loss(self, x_hat, targets, z_mu, z_sigma2_log, z, mode, mask): 305 | if mode == 's': 306 | pi = self.MST_OATD_S.pi_prior 307 | log_sigma2_c = self.MST_OATD_S.log_var_prior 308 | mu_c = self.MST_OATD_S.mu_prior 309 | elif mode == 't': 310 | pi = self.MST_OATD_T.pi_prior 311 | log_sigma2_c = self.MST_OATD_T.log_var_prior 312 | mu_c = self.MST_OATD_T.mu_prior 313 | 314 | reconstruction_loss = self.crit(x_hat[mask == 1], targets[mask == 1]) 315 | 316 | gaussian_loss = torch.mean(torch.mean(self.gaussian_pdf_log(z, z_mu, 
z_sigma2_log).unsqueeze(1) - 317 | self.gaussian_pdfs_log(z, mu_c, log_sigma2_c), dim=1), dim=-1).mean() 318 | 319 | pi = F.softmax(pi, dim=-1) 320 | z = z.unsqueeze(1) 321 | mu_c = mu_c.unsqueeze(0) 322 | log_sigma2_c = log_sigma2_c.unsqueeze(0) 323 | 324 | logits = - torch.sum(torch.pow(z - mu_c, 2) / torch.exp(log_sigma2_c), dim=-1) 325 | logits = F.softmax(logits, dim=-1) + 1e-10 326 | category_loss = torch.mean(torch.sum(logits * (torch.log(logits) - torch.log(pi).unsqueeze(0)), dim=-1)) 327 | 328 | loss = reconstruction_loss + gaussian_loss / self.hidden_size + category_loss * 0.1 329 | return loss 330 | 331 | def load_mst_oatd(self): 332 | checkpoint = torch.load(self.pretrained_path) 333 | self.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s']) 334 | self.MST_OATD_T.load_state_dict(checkpoint['model_state_dict_t']) 335 | 336 | gmm_params = torch.load(self.gmm_path) 337 | 338 | self.MST_OATD_S.pi_prior.data = torch.from_numpy(gmm_params['gmm_s_pi_prior']).to(self.device) 339 | self.MST_OATD_S.mu_prior.data = torch.from_numpy(gmm_params['gmm_s_mu_prior']).to(self.device) 340 | self.MST_OATD_S.log_var_prior.data = torch.from_numpy(gmm_params['gmm_s_logvar_prior']).to(self.device) 341 | 342 | self.MST_OATD_T.pi_prior.data = torch.from_numpy(gmm_params['gmm_t_pi_prior']).to(self.device) 343 | self.MST_OATD_T.mu_prior.data = torch.from_numpy(gmm_params['gmm_t_mu_prior']).to(self.device) 344 | self.MST_OATD_T.log_var_prior.data = torch.from_numpy(gmm_params['gmms_t_logvar_prior']).to(self.device) 345 | 346 | def get_prob(self, z): 347 | gmm = GaussianMixture(n_components=self.n_cluster, covariance_type='diag') 348 | gmm_params = torch.load(self.gmm_update_path) 349 | gmm.precisions_cholesky_ = gmm_params['gmm_update_precisions_cholesky'] 350 | gmm.weights_ = gmm_params['gmm_update_weights'] 351 | gmm.means_ = gmm_params['gmm_update_means'] 352 | gmm.covariances_ = gmm_params['gmm_update_covariances'] 353 | 354 | probs = gmm.predict_proba(z) 355 | 356 | for label in range(self.n_cluster): 357 | np.save('probs/probs_{}_{}.npy'.format(label, self.dataset), np.sort(-probs[:, label])) 358 | --------------------------------------------------------------------------------
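For reference, the objective assembled in the Loss method above is, up to the masking of padded positions, approximately

    loss ≈ CE(x_hat, targets)
           + (1 / hidden_size) · mean_c [ log N(z; z_mu, exp(z_sigma2_log)) − log N(z; mu_c, exp(log_sigma2_c)) ]
           + 0.1 · Σ_c γ_c · (log γ_c − log π_c)

where π = softmax(pi_prior), and γ_c ∝ exp(−Σ_d (z_d − mu_c[d])² / exp(log_sigma2_c[d])) is the soft cluster assignment computed against the GMM prior. This is an interpretation of the code, not a statement from the authors, but the three terms mirror the reconstruction, Gaussian and categorical components of a GMM-VAE-style objective.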
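The methods above suggest a two-stage pipeline: reconstruction pretraining, fitting GMM priors on the latent codes, then joint training and PR-AUC evaluation on the outlier set. A minimal driver sketch follows; it only calls methods visible in this file, but the module name in the import and the construction of the token sizes, labels and DataLoaders are assumed (in the repository they are presumably built by train.py together with the preprocessing and outlier-generation scripts).

    from config import args
    from mst_oatd_trainer import train_mst_oatd  # module name assumed from the file layout

    # s_token_size, t_token_size, labels, train_loader and outliers_loader are assumed
    # to be produced elsewhere (preprocessing + generate_outliers.py).
    trainer = train_mst_oatd(s_token_size, t_token_size, labels,
                             train_loader, outliers_loader, args)

    # Stage 1: reconstruction pretraining of the spatial (S) and temporal (T) branches.
    for epoch in range(args.pretrain_epochs):
        trainer.pretrain(epoch)

    # Stage 2: fit diagonal GMMs on the pretrained latents and install them as priors.
    trainer.train_gmm()
    trainer.save_weights_for_MSTOATD()
    trainer.load_mst_oatd()

    # Stage 3: joint training with the GMM priors, then PR-AUC on the outlier set.
    for epoch in range(args.epochs):
        trainer.train(epoch)
    print("PR-AUC:", trainer.detection())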