├── temporal_model
│   └── emb_128.pth
├── requirement.txt
├── utils.py
├── logging_set.py
├── LICENSE
├── config.py
├── temporal.py
├── README.md
├── preprocess
│   ├── preprocess_utils.py
│   ├── preprocess_cd.py
│   └── preprocess_porto.py
├── train_labels.py
├── train.py
├── generate_outliers.py
├── mst_oatd.py
├── train_update.py
└── mst_oatd_trainer.py
/temporal_model/emb_128.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chwang0721/MST-OATD/HEAD/temporal_model/emb_128.pth
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.4
2 | pandas==2.1.4
3 | scipy==1.8.1
4 | scikit-learn==1.3.2
5 | networkx==3.2.1
6 | geopy==2.4.1
7 | torch==2.1.0
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn.metrics import average_precision_score
3 |
4 |
5 | def make_len_mask(inp):
6 |     return inp == 0  # True at padded (token id 0) positions
7 |
8 |
9 | def make_mask(mask):
10 |     return (~mask).detach().type(torch.uint8)  # 1 at real tokens, 0 at padding
11 |
12 |
13 | def auc_score(y_true, y_score):
14 | # precision, recall, _ = precision_recall_curve(y_true, y_score)
15 | # return auc(recall, precision)
16 |     return average_precision_score(y_true, y_score)  # average precision, reported as PR_AUC
17 |
18 |
--------------------------------------------------------------------------------
/logging_set.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def get_logger(filename, verbosity=1, name=None):
5 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
6 | formatter = logging.Formatter(
7 | "[%(asctime)s][%(levelname)s] %(message)s"
8 | )
9 | logger = logging.getLogger(name)
10 | logger.setLevel(level_dict[verbosity])
11 |
12 | fh = logging.FileHandler(filename, "a")
13 | fh.setFormatter(formatter)
14 | logger.addHandler(fh)
15 |
16 | sh = logging.StreamHandler()
17 | sh.setFormatter(formatter)
18 | logger.addHandler(sh)
19 |
20 | return logger
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Chenhao Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument("--batch_size", type=int, default=1600)
5 |
6 | parser.add_argument('--embedding_size', type=int, default=128)
7 | parser.add_argument('--hidden_size', type=int, default=512)
8 | parser.add_argument('--n_cluster', type=int, default=20)
9 |
10 | parser.add_argument('--pretrain_lr_s', type=float, default=2e-3)
11 | parser.add_argument('--pretrain_lr_t', type=float, default=2e-3)
12 |
13 | parser.add_argument('--lr_s', type=float, default=2e-4)
14 | parser.add_argument('--lr_t', type=float, default=8e-5)
15 |
16 | parser.add_argument('--epochs', type=int, default=5)
17 | parser.add_argument('--pretrain_epochs', type=int, default=6)
18 |
19 | parser.add_argument("--ratio", type=float, default=0.05, help="ratio of outliers")
20 | parser.add_argument("--distance", type=int, default=2)
21 | parser.add_argument("--fraction", type=float, default=0.2)
22 | parser.add_argument("--obeserved_ratio", type=float, default=1.0)
23 |
24 | parser.add_argument("--device", type=str, default='cuda:0')
25 | parser.add_argument("--dataset", type=str, default='porto')
26 | parser.add_argument("--update_mode", type=str, default='pretrain')
27 |
28 | parser.add_argument("--train_num", type=int, default=80000) # 80000 200000
29 |
30 | parser.add_argument("--s1_size", type=int, default=2)
31 | parser.add_argument("--s2_size", type=int, default=4)
32 |
33 | parser.add_argument("--task", type=str, default='train')
34 |
35 | args = parser.parse_args()
36 |
37 | # python train.py --dataset porto --batch_size 1600 --pretrain_epochs 6
38 | # python train.py --dataset cd --batch_size 300 --pretrain_epochs 3 --epochs 4
39 |
--------------------------------------------------------------------------------
/temporal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class TemporalEmbedding:
5 | def __init__(self, device, model_path="./temporal_model/emb_128.pth"):
6 | self.model = TemporalVec().to(device)
7 | self.model_path = model_path
8 | self.device = device
9 |
10 | def __call__(self, x):
11 | self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
12 | with torch.no_grad():
13 | return self.model.encode(torch.Tensor(x).unsqueeze(0)).squeeze(0)
14 |
15 | class TemporalVec(nn.Module):
16 | def __init__(self, k=128, act="sin"):
17 | super(TemporalVec, self).__init__()
18 |
19 | if k % 2 == 0:
20 | k1 = k // 2
21 | k2 = k // 2
22 | else:
23 | k1 = k // 2
24 | k2 = k // 2 + 1
25 |
26 | self.fc1 = nn.Linear(6, k1)
27 | self.fc2 = nn.Linear(6, k2)
28 | self.d2 = nn.Dropout(0.3)
29 |
30 | if act == 'sin':
31 | self.activation = torch.sin
32 | else:
33 | self.activation = torch.cos
34 |
35 | self.fc3 = nn.Linear(k, k // 2)
36 | self.d3 = nn.Dropout(0.3)
37 | self.fc4 = nn.Linear(k // 2, 6)
38 | self.fc5 = torch.nn.Linear(6, 6)
39 |
40 | def forward(self, x):
41 | out1 = self.fc1(x)
42 | out2 = self.d2(self.activation(self.fc2(x)))
43 | out = torch.cat([out1, out2], 1)
44 | out = self.d3(self.fc3(out))
45 | out = self.fc4(out)
46 | out = self.fc5(out)
47 | return out
48 |
49 | def encode(self, x):
50 | out1 = self.fc1(x)
51 | out2 = self.activation(self.fc2(x))
52 | out = torch.cat([out1, out2], -1)
53 | return out
54 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MST-OATD
2 | Code for "Multi-Scale Detection of Anomalous Spatio-Temporal Trajectories in Evolving Trajectory Datasets"
3 | ### Requirements
4 | ```
5 | pip install -r requirement.txt
6 | ```
7 | ### Preprocessing
8 | - Step 1: Download the Porto dataset (train.csv.zip) from [Porto](https://www.kaggle.com/c/pkdd-15-predict-taxi-service-trajectory-i/data), and the Chengdu dataset (Chengdu.zip) from [Chengdu](https://www.dropbox.com/scl/fi/w4jylj9het6x93btxud6o/Chengdu.zip?rlkey=w6x00pzyjk4z7fvxwhkryeq1l&dl=0).
9 | - Step 2: Put the Porto data file in ../datasets/porto/ and unzip it to porto.csv. Put the unzipped Chengdu data in ../datasets/chengdu/.
10 | - Step 3: Run preprocessing by
11 | ```
12 | mkdir -p data/<dataset>
13 | cd preprocess
14 | python preprocess_<dataset>.py
15 | cd ..
16 | mkdir logs models probs
17 | ```
18 | <dataset>: porto or cd
19 |
20 | ### Generating ground truth
21 |
22 | ```
23 | python generate_outliers.py --distance 2 --fraction 0.2 --obeserved_ratio 1.0 --dataset <dataset>
24 | ```
25 | distance controls how far outlier points are moved, fraction is the fraction of consecutive points in a trajectory that are perturbed, and obeserved_ratio is the ratio of the observed part of a trajectory.
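For example, using the default settings from config.py on the Porto dataset:
```
python generate_outliers.py --distance 2 --fraction 0.2 --obeserved_ratio 1.0 --dataset porto
```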
26 | ### Training and testing
27 | ```
28 | python train.py --task train --dataset <dataset>
29 | python train.py --task test --distance 2 --fraction 0.2 --obeserved_ratio 1.0 --dataset <dataset>
30 | ```
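For reference, the comment lines at the end of config.py suggest the per-dataset settings, e.g.:
```
python train.py --dataset porto --batch_size 1600 --pretrain_epochs 6
python train.py --dataset cd --batch_size 300 --pretrain_epochs 3 --epochs 4
```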
31 | ### Training on evolving datasets
32 | ```
33 | python train_labels.py --dataset <dataset>
34 | python train_update.py --update_mode pretrain --dataset <dataset> --train_num <train_num>
35 | ```
36 | update_mode contains three modes: pretrain, temporal and rank; <train_num> is the number of trajectories used for evolving training.
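For example, an evolving-training run on Porto with the default train_num from config.py (80000):
```
python train_labels.py --dataset porto
python train_update.py --update_mode pretrain --dataset porto --train_num 80000
```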
37 |
38 | ### Citation
39 | Please kindly cite our work if you find our paper or code helpful.
40 |
41 | link: https://dl.acm.org/doi/10.1145/3637528.3671874
42 | ```
43 | @inproceedings{wang2024multi,
44 | title={Multi-Scale Detection of Anomalous Spatio-Temporal Trajectories in Evolving Trajectory Datasets},
45 | author={Wang, Chenhao and Chen, Lisi and Shang, Shuo and Jensen, Christian S and Kalnis, Panos},
46 | booktitle={Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
47 | pages={2980--2990},
48 | year={2024}
49 | }
50 | ```
51 |
--------------------------------------------------------------------------------
/preprocess/preprocess_utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import time
3 |
4 | import networkx as nx
5 | import numpy as np
6 | import scipy.sparse as sparse
7 | from geopy import distance
8 |
9 |
10 | # Determine whether a point is in boundary
11 | def in_boundary(lat, lng, b):
12 | return b['min_lng'] < lng < b['max_lng'] and b['min_lat'] < lat < b['max_lat']
13 |
14 |
15 | # Cut long trajectories
16 | def cutting_trajs(traj, longest, shortest):
17 | cutted_trajs = []
18 | while len(traj) > longest:
19 | random_length = np.random.randint(shortest, longest)
20 | cutted_traj = traj[:random_length]
21 | cutted_trajs.append(cutted_traj)
22 | traj = traj[random_length:]
23 | return cutted_trajs
24 |
25 |
26 | # convert datetime to time vector
27 | def convert_date(date_str):
28 |     timeArray = time.strptime(date_str, "%Y/%m/%d %H:%M:%S")
29 | t = [timeArray.tm_hour, timeArray.tm_min, timeArray.tm_sec, timeArray.tm_year, timeArray.tm_mon, timeArray.tm_mday]
30 | return t
31 |
32 |
33 | # Calculate timestamp gap
34 | def timestamp_gap(str1, str2):
35 | timestamp1 = datetime.datetime.strptime(str1, "%Y/%m/%d %H:%M:%S")
36 | timestamp2 = datetime.datetime.strptime(str2, "%Y/%m/%d %H:%M:%S")
37 | return (timestamp2 - timestamp1).total_seconds()
38 |
39 |
40 | # Map trajectories to grids
41 | def grid_mapping(boundary, grid_size):
42 | lat_dist = distance.distance((boundary['min_lat'], boundary['min_lng']),
43 | (boundary['max_lat'], boundary['min_lng'])).km
44 | lat_size = (boundary['max_lat'] - boundary['min_lat']) / lat_dist * grid_size
45 |
46 | lng_dist = distance.distance((boundary['min_lat'], boundary['min_lng']),
47 | (boundary['min_lat'], boundary['max_lng'])).km
48 | lng_size = (boundary['max_lng'] - boundary['min_lng']) / lng_dist * grid_size
49 |
50 | lat_grid_num = int(lat_dist / grid_size) + 1
51 | lng_grid_num = int(lng_dist / grid_size) + 1
52 | return lat_size, lng_size, lat_grid_num, lng_grid_num
53 |
54 |
55 | # Generate adjacency matrix and normalized degree matrix
56 | def generate_matrix(lat_grid_num, lng_grid_num):
57 | G = nx.grid_2d_graph(lat_grid_num, lng_grid_num, periodic=False)
58 | A = nx.adjacency_matrix(G)
59 | I = sparse.identity(lat_grid_num * lng_grid_num)
60 | D = np.diag(np.sum(A + I, axis=1))
61 | D = 1 / (np.sqrt(D) + 1e-10)
62 | D[D == 1e10] = 0.
63 | D = sparse.csr_matrix(D)
64 | return A + I, D
65 |
--------------------------------------------------------------------------------
/train_labels.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | from config import args
8 |
9 |
10 | class LinearRegression(torch.nn.Module):
11 | def __init__(self):
12 | super().__init__()
13 | self.linear = torch.nn.Linear(1, 1)
14 | self.activate = nn.Sigmoid()
15 |
16 | def forward(self, x):
17 | out = self.activate(self.linear(x))
18 | return out
19 |
20 |
21 | class Linear_Model:
22 | def __init__(self):
23 | self.learning_rate = 5e-4
24 | self.epoches = 100000
25 | self.loss_function = torch.nn.MSELoss()
26 | self.create_model()
27 |
28 | def create_model(self):
29 | self.model = LinearRegression().to(args.device)
30 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
31 |
32 | def train(self, num, data):
33 |
34 | x = data.unsqueeze(1)
35 | x = (x - torch.mean(x)) / torch.std(x)
36 | y = torch.arange(1, len(x) + 1).unsqueeze(1) / len(x)
37 |
38 | temp = 100
39 | for epoch in range(self.epoches):
40 | prediction = self.model(x.to(args.device))
41 | loss = self.loss_function(prediction, y.to(args.device))
42 |
43 | self.optimizer.zero_grad()
44 | loss.backward()
45 | self.optimizer.step()
46 |
47 | if (epoch + 1) % 10000 == 0:
48 | print("Epoch: {} loss: {}".format(epoch + 1, loss.item()))
49 | if (temp - loss) < 1e-7:
50 | break
51 | else:
52 | temp = loss
53 | torch.save(self.model.state_dict(), "probs/linear_{}_{}.pth".format(num, args.dataset))
54 |
55 | def test(self, num, data):
56 | x = data.unsqueeze(1)
57 | x = (x - torch.mean(x)) / torch.std(x)
58 |
59 | self.model.load_state_dict(torch.load("probs/linear_{}_{}.pth".format(num, args.dataset)))
60 | self.model.eval()
61 | prediction = self.model(x).cpu()
62 |
63 | return (prediction * len(x)).to(dtype=torch.int32).squeeze(1).detach().numpy()
64 |
65 |
66 | if __name__ == '__main__':
67 | pool = Pool(processes=10)
68 | for num in range(args.n_cluster):
69 | data = np.load('probs/probs_{}_{}.npy'.format(num, args.dataset))
70 | data = torch.Tensor(data)
71 | linear = Linear_Model()
72 | pool.apply_async(linear.train, (num, data))
73 | pool.close()
74 | pool.join()
75 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader
4 |
5 | from config import args
6 | from mst_oatd_trainer import train_mst_oatd, MyDataset, seed_torch, collate_fn
7 |
8 |
9 | def main():
10 | train_trajs = np.load('./data/{}/train_data_init.npy'.format(args.dataset), allow_pickle=True)
11 | test_trajs = np.load('./data/{}/outliers_data_init_{}_{}_{}.npy'.format(args.dataset, args.distance, args.fraction,
12 | args.obeserved_ratio), allow_pickle=True)
13 | outliers_idx = np.load("./data/{}/outliers_idx_init_{}_{}_{}.npy".format(args.dataset, args.distance, args.fraction,
14 | args.obeserved_ratio), allow_pickle=True)
15 |
16 | train_data = MyDataset(train_trajs)
17 | test_data = MyDataset(test_trajs)
18 |
19 | labels = np.zeros(len(test_trajs))
20 | for i in outliers_idx:
21 | labels[i] = 1
22 |     # labels: 1 marks generated outliers, 0 marks normal trajectories
23 |
24 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn,
25 | num_workers=8, pin_memory=True)
26 | outliers_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn,
27 | num_workers=8, pin_memory=True)
28 |
29 | MST_OATD = train_mst_oatd(s_token_size, t_token_size, labels, train_loader, outliers_loader, args)
30 |
31 | if args.task == 'train':
32 |
33 | MST_OATD.logger.info("Start pretraining!")
34 |
35 | for epoch in range(args.pretrain_epochs):
36 | MST_OATD.pretrain(epoch)
37 |
38 | MST_OATD.train_gmm()
39 | MST_OATD.save_weights_for_MSTOATD()
40 |
41 | MST_OATD.logger.info("Start training!")
42 | MST_OATD.load_mst_oatd()
43 | for epoch in range(args.epochs):
44 | MST_OATD.train(epoch)
45 |
46 | if args.task == 'test':
47 |
48 | MST_OATD.logger.info('Start testing!')
49 | MST_OATD.logger.info("d = {}".format(args.distance) + ", " + chr(945) + " = {}".format(args.fraction) + ", "
50 | + chr(961) + " = {}".format(args.obeserved_ratio))
51 |
52 | checkpoint = torch.load(MST_OATD.path_checkpoint)
53 | MST_OATD.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s'])
54 | MST_OATD.MST_OATD_T.load_state_dict(checkpoint['model_state_dict_t'])
55 | pr_auc = MST_OATD.detection()
56 | pr_auc = "%.4f" % pr_auc
57 | MST_OATD.logger.info("PR_AUC: {}".format(pr_auc))
58 |
59 | if args.task == 'train':
60 | MST_OATD.train_gmm_update()
61 | z = MST_OATD.get_hidden()
62 | MST_OATD.get_prob(z.cpu())
63 |
64 |
65 | if __name__ == "__main__":
66 |
67 | if args.dataset == 'porto':
68 | s_token_size = 51 * 119
69 | t_token_size = 5760
70 |
71 | elif args.dataset == 'cd':
72 | s_token_size = 167 * 154
73 | t_token_size = 8640
74 |
75 | main()
76 |
--------------------------------------------------------------------------------
/preprocess/preprocess_cd.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import partial
3 | from multiprocessing import Pool, Manager
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import scipy.sparse as sparse
8 | from sklearn.model_selection import train_test_split
9 |
10 | from preprocess_utils import *
11 |
12 |
13 | def preprocess(file, traj_path, shortest, longest, boundary, lat_size, lng_size, lng_grid_num, convert_date,
14 | timestamp_gap, in_boundary, cutting_trajs, traj_nums, point_nums):
15 | np.random.seed(1234)
16 | # Read and sort trajectories based on id and timestamp
17 | data = pd.read_csv("{}/{}".format(traj_path, file), header=None)
18 | data.columns = ['id', 'lat', 'lon', 'state', 'timestamp']
19 | data = data.sort_values(by=['id', 'timestamp'])
20 | data = data[data['state'] == 1]
21 |
22 | trajs = []
23 | traj_seq = []
24 | valid = True
25 |
26 | pre_point = data.iloc[0]
27 |
28 | # Select trajectories
29 | for point in data.itertuples():
30 |
31 | if point.id == pre_point.id and timestamp_gap(pre_point.timestamp, point.timestamp) <= 20:
32 | if in_boundary(point.lat, point.lon, boundary):
33 | grid_i = int((point.lat - boundary['min_lat']) / lat_size)
34 | grid_j = int((point.lon - boundary['min_lng']) / lng_size)
35 | traj_seq.append([grid_i * lng_grid_num + grid_j, convert_date(point[5])])
36 | else:
37 | valid = False
38 |
39 | else:
40 | if valid:
41 | if shortest <= len(traj_seq) <= longest:
42 | trajs.append(traj_seq)
43 | elif len(traj_seq) > longest:
44 | trajs += cutting_trajs(traj_seq, longest, shortest)
45 |
46 | traj_seq = []
47 | valid = True
48 | pre_point = point
49 |
50 | traj_nums.append(len(trajs))
51 | point_nums.append(sum([len(traj) for traj in trajs]))
52 |
53 | train_data, test_data = train_test_split(trajs, test_size=0.2, random_state=42)
54 | np.save("../data/cd/train_data_{}.npy".format(file[:8]), np.array(train_data, dtype=object))
55 | np.save("../data/cd/test_data_{}.npy".format(file[:8]), np.array(test_data, dtype=object))
56 |
57 |
58 | # Parallel preprocess
59 | def batch_preprocess(path_list):
60 | manager = Manager()
61 | traj_nums = manager.list()
62 | point_nums = manager.list()
63 | pool = Pool(processes=10)
64 | pool.map(partial(preprocess, traj_path=traj_path, shortest=shortest, longest=longest, boundary=boundary,
65 | lat_size=lat_size, lng_size=lng_size, lng_grid_num=lng_grid_num, convert_date=convert_date,
66 | timestamp_gap=timestamp_gap, in_boundary=in_boundary, cutting_trajs=cutting_trajs,
67 | traj_nums=traj_nums, point_nums=point_nums), path_list)
68 | pool.close()
69 | pool.join()
70 |
71 | num_trajs = sum(traj_nums)
72 | num_points = sum(point_nums)
73 | print("Total trajectory num:", num_trajs)
74 | print("Total point num:", num_points)
75 |
76 |
77 | def merge(path_list):
78 | res_train = []
79 | res_test = []
80 |
81 | for file in path_list:
82 |
83 | train_trajs = np.load("../data/cd/train_data_{}.npy".format(file[:8]), allow_pickle=True)
84 | test_trajs = np.load("../data/cd/test_data_{}.npy".format(file[:8]), allow_pickle=True)
85 | res_train.append(train_trajs)
86 | res_test.append(test_trajs)
87 |
88 | res_train = np.concatenate(res_train, axis=0)
89 | res_test = np.concatenate(res_test, axis=0)
90 |
91 | return res_train, res_test
92 |
93 |
94 | def main():
95 | path_list = os.listdir(traj_path)
96 | path_list.sort(key=lambda x: x.split('.'))
97 | path_list = path_list[:10]
98 |
99 | batch_preprocess(path_list)
100 | train_data, test_data = merge(path_list[:3])
101 |
102 | np.save("../data/cd/train_data_init.npy", np.array(train_data, dtype=object))
103 | np.save("../data/cd/test_data_init.npy", np.array(test_data, dtype=object))
104 |
105 |     print('Finished!')
106 |
107 |
108 | if __name__ == "__main__":
109 | traj_path = "../../datasets/chengdu"
110 |
111 | grid_size = 0.1
112 | shortest, longest = 30, 100
113 | boundary = {'min_lat': 30.6, 'max_lat': 30.75, 'min_lng': 104, 'max_lng': 104.16}
114 |
115 | lat_size, lng_size, lat_grid_num, lng_grid_num = grid_mapping(boundary, grid_size)
116 | A, D = generate_matrix(lat_grid_num, lng_grid_num)
117 |
118 | sparse.save_npz('../data/cd/adj.npz', A)
119 | sparse.save_npz('../data/cd/d_norm.npz', D)
120 |
121 | print('Grid size:', (lat_grid_num, lng_grid_num))
122 | print('----------Preprocessing----------')
123 | main()
124 |
--------------------------------------------------------------------------------
/preprocess/preprocess_porto.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import random
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import scipy.sparse as sparse
8 | from sklearn.model_selection import train_test_split
9 |
10 | from preprocess_utils import *
11 |
12 |
13 | def time_convert(timestamp):
14 | return datetime.datetime.fromtimestamp(timestamp)
15 |
16 |
17 | def preprocess(trajectories, traj_num, point_num):
18 | trajs = [] # Preprocessed trajectories
19 |
20 | for traj in trajectories.itertuples():
21 |
22 | traj_seq = []
23 | valid = True # Flag to determine whether a trajectory is in boundary
24 |
25 | polyline = json.loads(traj.POLYLINE)
26 | timestamp = traj.TIMESTAMP
27 |
28 | if len(polyline) >= shortest:
29 | for lng, lat in polyline:
30 |
31 | if in_boundary(lat, lng, boundary):
32 | grid_i = int((lat - boundary['min_lat']) / lat_size)
33 | grid_j = int((lng - boundary['min_lng']) / lng_size)
34 |
35 | t = datetime.datetime.fromtimestamp(timestamp)
36 | t = [t.hour, t.minute, t.second, t.year, t.month, t.day] # Time vector
37 |
38 | traj_seq.append([int(grid_i * lng_grid_num + grid_j), t])
39 |                     timestamp += 15  # In the Porto dataset, the sampling interval is 15 seconds
40 |
41 | else:
42 | valid = False
43 | break
44 |
45 | # Randomly delete 30% trajectory points to make the sampling rate not fixed
46 | to_delete = set(random.sample(range(len(traj_seq)), int(len(traj_seq) * 0.3)))
47 | traj_seq = [item for index, item in enumerate(traj_seq) if index not in to_delete]
48 |
49 | # Lengths are limited from 20 to 50
50 | if valid:
51 | if len(traj_seq) <= longest:
52 | trajs.append(traj_seq)
53 | else:
54 | trajs += cutting_trajs(traj_seq, longest, shortest)
55 |
56 | traj_num += len(trajs)
57 |
58 | for traj in trajs:
59 | point_num += len(traj)
60 |
61 | return trajs, traj_num, point_num
62 |
63 |
64 | def main():
65 | # Read csv file
66 | trajectories = pd.read_csv("{}/{}.csv".format(data_dir, data_name), header=0, usecols=['POLYLINE', 'TIMESTAMP'])
67 | trajectories['datetime'] = trajectories['TIMESTAMP'].apply(time_convert)
68 |
69 |     # Initial dataset
70 | start_time = datetime.datetime(2013, 7, 1, 0, 0, 0)
71 | end_time = datetime.datetime(2013, 9, 1, 0, 0, 0)
72 |
73 | traj_num, point_num = 0, 0
74 |
75 | # Select trajectories from start time to end time
76 | trajs = trajectories[(trajectories['datetime'] >= start_time) & (trajectories['datetime'] < end_time)]
77 | preprocessed_trajs, traj_num, point_num = preprocess(trajs, traj_num, point_num)
78 | train_data, test_data = train_test_split(preprocessed_trajs, test_size=0.2, random_state=42)
79 |
80 | np.save("../data/porto/train_data_init.npy", np.array(train_data, dtype=object))
81 | np.save("../data/porto/test_data_init.npy", np.array(test_data, dtype=object))
82 |
83 | start_time = datetime.datetime(2013, 9, 1, 0, 0, 0)
84 |
85 | # Evolving dataset
86 | for month in range(1, 11):
87 | end_time = start_time + datetime.timedelta(days=30)
88 | trajs = trajectories[(trajectories['datetime'] >= start_time) & (trajectories['datetime'] < end_time)]
89 |
90 | preprocessed_trajs, traj_num, point_num = preprocess(trajs, traj_num, point_num)
91 | train_data, test_data = train_test_split(preprocessed_trajs, test_size=0.2, random_state=42)
92 |
93 | np.save("../data/porto/train_data_{}.npy".format(month), np.array(train_data, dtype=object))
94 | np.save("../data/porto/test_data_{}.npy".format(month), np.array(test_data, dtype=object))
95 |
96 | start_time = end_time
97 |
98 | # Dataset statistics
99 | print("Total trajectory num:", traj_num)
100 | print("Total point num:", point_num)
101 |
102 |     print('Finished!')
103 |
104 |
105 | if __name__ == '__main__':
106 | random.seed(1234)
107 | np.random.seed(1234)
108 |
109 | data_dir = '../../datasets/porto'
110 | data_name = "porto"
111 |
112 | boundary = {'min_lat': 41.140092, 'max_lat': 41.185969, 'min_lng': -8.690261, 'max_lng': -8.549155}
113 | grid_size = 0.1
114 | shortest, longest = 20, 50
115 |
116 | lat_size, lng_size, lat_grid_num, lng_grid_num = grid_mapping(boundary, grid_size)
117 | A, D = generate_matrix(lat_grid_num, lng_grid_num)
118 |
119 | sparse.save_npz('../data/porto/adj.npz', A)
120 | sparse.save_npz('../data/porto/d_norm.npz', D)
121 |
122 | print('Grid size:', (lat_grid_num, lng_grid_num))
123 | print('----------Preprocessing----------')
124 | main()
125 |
--------------------------------------------------------------------------------
/generate_outliers.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import os
4 | from datetime import timedelta
5 |
6 | import numpy as np
7 |
8 | from config import args
9 |
10 |
11 | # Trajectory location offset
12 | def perturb_point(point, level, offset=None):
13 | point, times = point[0], point[1]
14 | x, y = int(point // map_size[1]), int(point % map_size[1])
15 |
16 | if offset is None:
17 | offset = [[0, 1], [1, 0], [-1, 0], [0, -1], [1, 1], [-1, -1], [-1, 1], [1, -1]]
18 | x_offset, y_offset = offset[np.random.randint(0, len(offset))]
19 |
20 | else:
21 | x_offset, y_offset = offset
22 |
23 | if 0 <= x + x_offset * level < map_size[0] and 0 <= y + y_offset * level < map_size[1]:
24 | x += x_offset * level
25 | y += y_offset * level
26 |
27 | return [int(x * map_size[1] + y), times]
28 |
29 |
30 | def convert(point):
31 | x, y = int(point // map_size[1]), int(point % map_size[1])
32 | return [x, y]
33 |
34 |
35 | def distance(a, b):
36 | return math.sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2)
37 |
38 |
39 | def time_calculate(vec, s):
40 | a = datetime.datetime(vec[3], vec[4], vec[5], vec[0], vec[1], vec[2])
41 | t = a + timedelta(seconds=s)
42 | return [t.hour, t.minute, t.second, t.year, t.month, t.day]
43 |
44 |
45 | # Trajectory time offset
46 | def perturb_time(traj, st_loc, end_loc, time_offset, interval):
47 | for i in range(st_loc, end_loc):
48 |         traj[i][1] = time_calculate(traj[i][1], int((i - st_loc + 1) * time_offset * interval))
49 |
50 | for i in range(end_loc, len(traj)):
51 |         traj[i][1] = time_calculate(traj[i][1], int((end_loc - st_loc) * time_offset * interval))
52 | return traj
53 |
54 |
55 | def perturb_batch(batch_x, level, prob, selected_idx):
56 | noisy_batch_x = []
57 |
58 | if args.dataset == 'porto':
59 | interval = 15
60 | else:
61 | interval = 10
62 |
63 | for idx, traj in enumerate(batch_x):
64 |
65 | anomaly_len = int(len(traj) * prob)
66 | anomaly_st_loc = np.random.randint(1, len(traj) - anomaly_len - 1)
67 |
68 | if idx in selected_idx:
69 | anomaly_ed_loc = anomaly_st_loc + anomaly_len
70 |
71 | p_traj = traj[:anomaly_st_loc] + [perturb_point(p, level) for p in
72 | traj[anomaly_st_loc:anomaly_ed_loc]] + traj[anomaly_ed_loc:]
73 |
74 | dis = max(distance(convert(traj[anomaly_st_loc][0]), convert(traj[anomaly_ed_loc][0])), 1)
75 | time_offset = (level * 2) / dis
76 |
77 | p_traj = perturb_time(p_traj, anomaly_st_loc, anomaly_ed_loc, time_offset, interval)
78 |
79 | else:
80 | p_traj = traj
81 |
82 | p_traj = p_traj[:int(len(p_traj) * args.obeserved_ratio)]
83 | noisy_batch_x.append(p_traj)
84 |
85 | return noisy_batch_x
86 |
87 |
88 | def generate_outliers(trajs, ratio=args.ratio, level=args.distance, point_prob=args.fraction):
89 | traj_num = len(trajs)
90 | selected_idx = np.random.randint(0, traj_num, size=int(traj_num * ratio))
91 | new_trajs = perturb_batch(trajs, level, point_prob, selected_idx)
92 | return new_trajs, selected_idx
93 |
94 |
95 | if __name__ == '__main__':
96 | np.random.seed(1234)
97 | print("=========================")
98 | print("Dataset: " + args.dataset)
99 | print("d = {}".format(args.distance) + ", " + chr(945) + " = {}".format(args.fraction) + ", "
100 | + chr(961) + " = {}".format(args.obeserved_ratio))
101 |
102 | if args.dataset == 'porto':
103 | map_size = (51, 119)
104 | elif args.dataset == 'cd':
105 | map_size = (167, 154)
106 |
107 | data = np.load("./data/{}/test_data_init.npy".format(args.dataset), allow_pickle=True)
108 | outliers_trajs, outliers_idx = generate_outliers(data)
109 | outliers_trajs = np.array(outliers_trajs, dtype=object)
110 | outliers_idx = np.array(outliers_idx)
111 |
112 | np.save("./data/{}/outliers_data_init_{}_{}_{}.npy".format(args.dataset, args.distance, args.fraction,
113 | args.obeserved_ratio), outliers_trajs)
114 | np.save("./data/{}/outliers_idx_init_{}_{}_{}.npy".format(args.dataset, args.distance, args.fraction,
115 | args.obeserved_ratio), outliers_idx)
116 |
117 | if args.dataset == 'cd':
118 |
119 | traj_path = "../datasets/chengdu"
120 | path_list = os.listdir(traj_path)
121 | path_list.sort(key=lambda x: x.split('.'))
122 |
123 | for file in path_list[3: 10]:
124 | if file[-4:] == '.txt':
125 | data = np.load("./data/{}/test_data_{}.npy".format(args.dataset, file[:8]),
126 | allow_pickle=True)
127 | outliers_trajs, outliers_idx = generate_outliers(data)
128 | outliers_trajs = np.array(outliers_trajs, dtype=object)
129 | outliers_idx = np.array(outliers_idx)
130 |
131 | np.save("./data/{}/outliers_data_{}_{}_{}_{}.npy".format(args.dataset, file[:8], args.distance,
132 | args.fraction, args.obeserved_ratio), outliers_trajs)
133 | np.save("./data/{}/outliers_idx_{}_{}_{}_{}.npy".format(args.dataset, file[:8], args.distance,
134 | args.fraction, args.obeserved_ratio), outliers_idx)
135 |
136 | if args.dataset == 'porto':
137 | for i in range(1, 11):
138 | data = np.load("./data/{}/test_data_{}.npy".format(args.dataset, i), allow_pickle=True)
139 | outliers_trajs, outliers_idx = generate_outliers(data)
140 | outliers_trajs = np.array(outliers_trajs, dtype=object)
141 | outliers_idx = np.array(outliers_idx)
142 |
143 | np.save("./data/{}/outliers_data_{}_{}_{}_{}.npy".format(args.dataset, i, args.distance,
144 | args.fraction, args.obeserved_ratio), outliers_trajs)
145 | np.save("./data/{}/outliers_idx_{}_{}_{}_{}.npy".format(args.dataset, i, args.distance,
146 | args.fraction, args.obeserved_ratio), outliers_idx)
147 |
--------------------------------------------------------------------------------
/mst_oatd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sparse
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7 |
8 | from temporal import TemporalEmbedding
9 |
10 |
11 | class co_attention(nn.Module):
12 | def __init__(self, dim):
13 | super(co_attention, self).__init__()
14 |
15 | self.Wq_s = nn.Linear(dim, dim, bias=False)
16 | self.Wk_s = nn.Linear(dim, dim, bias=False)
17 | self.Wv_s = nn.Linear(dim, dim, bias=False)
18 |
19 | self.Wq_t = nn.Linear(dim, dim, bias=False)
20 | self.Wk_t = nn.Linear(dim, dim, bias=False)
21 | self.Wv_t = nn.Linear(dim, dim, bias=False)
22 |
23 | self.dim_k = dim ** 0.5
24 |
25 | self.FFN_s = nn.Sequential(
26 | nn.Linear(dim, dim),
27 | nn.ReLU(),
28 | nn.Linear(dim, dim),
29 | nn.Dropout(0.1)
30 | )
31 |
32 | self.FFN_t = nn.Sequential(
33 | nn.Linear(dim, dim),
34 | nn.ReLU(),
35 | nn.Linear(dim, dim),
36 | nn.Dropout(0.1)
37 | )
38 |
39 | self.layer_norm = nn.LayerNorm(dim, eps=1e-6)
40 |
41 | def forward(self, seq_s, seq_t):
42 | seq_t = seq_t.unsqueeze(2)
43 | seq_s = seq_s.unsqueeze(2)
44 |
45 | q_s, k_s, v_s = self.Wq_s(seq_t), self.Wk_s(seq_s), self.Wv_s(seq_s)
46 | q_t, k_t, v_t = self.Wq_t(seq_s), self.Wk_t(seq_t), self.Wv_t(seq_t)
47 |
48 | coatt_s = F.softmax(torch.matmul(q_s / self.dim_k, k_s.transpose(2, 3)), dim=-1)
49 | coatt_t = F.softmax(torch.matmul(q_t / self.dim_k, k_t.transpose(2, 3)), dim=-1)
50 |
51 | att_s = self.layer_norm(self.FFN_s(torch.matmul(coatt_s, v_s)) + torch.matmul(coatt_s, v_s))
52 | att_t = self.layer_norm(self.FFN_t(torch.matmul(coatt_t, v_t)) + torch.matmul(coatt_t, v_t))
53 |
54 | return att_s.squeeze(2), att_t.squeeze(2)
55 |
56 |
57 | class state_attention(nn.Module):
58 | def __init__(self, args):
59 | super(state_attention, self).__init__()
60 |
61 | self.w_omega = nn.Parameter(torch.Tensor(args.hidden_size, args.hidden_size))
62 | self.u_omega = nn.Parameter(torch.Tensor(args.hidden_size, 1))
63 | nn.init.uniform_(self.w_omega, -0.1, 0.1)
64 | nn.init.uniform_(self.u_omega, -0.1, 0.1)
65 |
66 | def forward(self, seq):
67 | u = torch.tanh(torch.matmul(seq, self.w_omega))
68 | att = torch.matmul(u, self.u_omega).squeeze()
69 | att_score = F.softmax(att, dim=1).unsqueeze(2)
70 | scored_outputs = seq * att_score
71 | return scored_outputs.sum(1)
72 |
73 |
74 | class MST_OATD(nn.Module):
75 | def __init__(self, token_size, token_size_out, args):
76 | super(MST_OATD, self).__init__()
77 |
78 | self.emb_size = args.embedding_size
79 | self.device = args.device
80 | self.n_cluster = args.n_cluster
81 | self.dataset = args.dataset
82 | self.s1_size = args.s1_size
83 | self.s2_size = args.s2_size
84 |
85 | self.pi_prior = nn.Parameter(torch.ones(args.n_cluster) / args.n_cluster)
86 | self.mu_prior = nn.Parameter(torch.randn(args.n_cluster, args.hidden_size))
87 | self.log_var_prior = nn.Parameter(torch.zeros(args.n_cluster, args.hidden_size))
88 |
89 | self.embedding = nn.Embedding(token_size, args.embedding_size)
90 |
91 | self.encoder_s1 = nn.GRU(args.embedding_size * 2, args.hidden_size, 1, batch_first=True)
92 | self.encoder_s2 = nn.GRU(args.embedding_size * 2, args.hidden_size, 1, batch_first=True)
93 | self.encoder_s3 = nn.GRU(args.embedding_size * 2, args.hidden_size, 1, batch_first=True)
94 |
95 | self.decoder = nn.GRU(args.embedding_size * 2, args.hidden_size, 1, batch_first=True)
96 |
97 | self.fc_mu = nn.Linear(args.hidden_size, args.hidden_size)
98 | self.fc_logvar = nn.Linear(args.hidden_size, args.hidden_size)
99 | self.layer_norm = nn.LayerNorm(args.hidden_size)
100 |
101 | self.fc_out = nn.Linear(args.hidden_size, token_size_out)
102 |
103 | self.nodes = torch.arange(token_size, dtype=torch.long).to(args.device)
104 | self.adj = sparse.load_npz("data/{}/adj.npz".format(args.dataset))
105 | self.d_norm = sparse.load_npz("data/{}/d_norm.npz".format(args.dataset))
106 |
107 | if args.dataset == 'porto':
108 | self.V = nn.Parameter(torch.Tensor(token_size, token_size))
109 | else:
110 | self.V = nn.Parameter(torch.Tensor(args.embedding_size, args.embedding_size))
111 |
112 | self.W1 = nn.Parameter(torch.ones(1) / 3)
113 | self.W2 = nn.Parameter(torch.ones(1) / 3)
114 | self.W3 = nn.Parameter(torch.ones(1) / 3)
115 |
116 | self.co_attention = co_attention(args.embedding_size).to(args.device)
117 | self.d2v = TemporalEmbedding(args.device)
118 |
119 | self.w_omega = nn.Parameter(torch.Tensor(args.embedding_size * 2, args.embedding_size * 2))
120 | self.u_omega = nn.Parameter(torch.Tensor(args.embedding_size * 2, 1))
121 |
122 | nn.init.uniform_(self.V, -0.2, 0.2)
123 | nn.init.uniform_(self.w_omega, -0.1, 0.1)
124 | nn.init.uniform_(self.u_omega, -0.1, 0.1)
125 |
126 | self.state_att = state_attention(args)
127 | self.dataset = args.dataset
128 |
129 | def scale_process(self, e_inputs, scale_size, lengths):
130 | e_inputs_split = torch.mean(e_inputs.unfold(1, scale_size, scale_size), dim=3)
131 | e_inputs_split = self.attention_layer(e_inputs_split, lengths)
132 | e_inputs_split = pack_padded_sequence(e_inputs_split, lengths, batch_first=True, enforce_sorted=False)
133 | return e_inputs_split
134 |
135 | def Norm_A(self, A, D):
136 | return D.mm(A).mm(self.V).mm(D)
137 |
138 | def Norm_A_N(self, A, D):
139 | return D.mm(A).mm(D)
140 |
141 | def reparameterize(self, mu, log_var):
142 | std = torch.exp(log_var * 0.5)
143 | eps = torch.randn_like(std)
144 | return mu + eps * std
145 |
146 | def padding_mask(self, inp):
147 | return inp == 0
148 |
149 | def attention_layer(self, e_input, lengths):
150 | mask = self.getMask(lengths)
151 | u = torch.tanh(torch.matmul(e_input, self.w_omega))
152 | att = torch.matmul(u, self.u_omega).squeeze()
153 | att = att.masked_fill(mask == 0, -1e10)
154 | att_score = F.softmax(att, dim=1).unsqueeze(2)
155 | att_e_input = e_input * att_score
156 | return att_e_input
157 |
158 | def array2sparse(self, A):
159 | A = A.tocoo()
160 | values = A.data
161 | indices = np.vstack((A.row, A.col))
162 | i = torch.LongTensor(indices).to(self.device)
163 | v = torch.FloatTensor(values).to(self.device)
164 | A = torch.sparse_coo_tensor(i, v, torch.Size(A.shape), dtype=torch.float32)
165 | return A
166 |
167 | def getMask(self, seq_lengths):
168 | max_len = max(seq_lengths)
169 | mask = torch.ones((len(seq_lengths), max_len)).to(self.device)
170 |
171 | for i, l in enumerate(seq_lengths):
172 | if l < max_len:
173 | mask[i, l:] = 0
174 | return mask
175 |
176 | def forward(self, trajs, times, lengths, batch_size, mode, c):
177 |
178 | # spatial embedding
179 | adj = self.array2sparse(self.adj)
180 | d_norm = self.array2sparse(self.d_norm)
181 | if self.dataset == 'porto':
182 | H = self.Norm_A(adj, d_norm)
183 | nodes = H.mm(self.embedding(self.nodes))
184 | else:
185 | H = self.Norm_A_N(adj, d_norm)
186 | nodes = H.mm(self.embedding(self.nodes)).mm(self.V)
187 |
188 | s_inputs = torch.index_select(nodes, 0, trajs.flatten().to(self.device)). \
189 | reshape(batch_size, -1, self.emb_size)
190 |
191 | # temporal embedding
192 | t_inputs = self.d2v(times.to(self.device)).to(self.device)
193 |
194 | att_s, att_t = self.co_attention(s_inputs, t_inputs)
195 | st_inputs = torch.concat((att_s, att_t), dim=2)
196 | d_inputs = torch.cat((torch.zeros(batch_size, 1, self.emb_size * 2, dtype=torch.long).to(self.device),
197 | st_inputs[:, :-1, :]), dim=1) # [bs, seq_len, emb_size * 2]
198 |
199 | decoder_inputs = pack_padded_sequence(d_inputs, lengths, batch_first=True, enforce_sorted=False)
200 |
201 |         if mode == "pretrain" or mode == "train":
202 | encoder_inputs_s1 = pack_padded_sequence(self.attention_layer(st_inputs, lengths), lengths,
203 | batch_first=True, enforce_sorted=False)
204 | encoder_inputs_s2 = self.scale_process(st_inputs, self.s1_size, [int(i // self.s1_size) for i in lengths])
205 | encoder_inputs_s3 = self.scale_process(st_inputs, self.s2_size, [int(i // self.s2_size) for i in lengths])
206 |
207 | _, encoder_final_state_s1 = self.encoder_s1(encoder_inputs_s1)
208 | _, encoder_final_state_s2 = self.encoder_s2(encoder_inputs_s2)
209 | _, encoder_final_state_s3 = self.encoder_s3(encoder_inputs_s3)
210 |
211 | encoder_final_state = (self.W1 * encoder_final_state_s1 + self.W2 * encoder_final_state_s2
212 | + self.W3 * encoder_final_state_s3)
213 | sum_W = self.W1.data + self.W2.data + self.W3.data
214 | self.W1.data /= sum_W
215 | self.W2.data /= sum_W
216 | self.W3.data /= sum_W
217 |
218 | mu = self.fc_mu(encoder_final_state)
219 | logvar = self.fc_logvar(encoder_final_state)
220 | z = self.reparameterize(mu, logvar)
221 |
222 | decoder_outputs, _ = self.decoder(decoder_inputs, z)
223 | decoder_outputs, _ = pad_packed_sequence(decoder_outputs, batch_first=True)
224 |
225 | elif mode == "test":
226 | mu = torch.stack([self.mu_prior] * batch_size, dim=1)[c: c + 1]
227 | decoder_outputs, _ = self.decoder(decoder_inputs, mu)
228 | decoder_outputs, _ = pad_packed_sequence(decoder_outputs, batch_first=True)
229 | logvar, z = None, None
230 |
231 | output = self.fc_out(self.layer_norm(decoder_outputs))
232 |
233 | return output, mu, logvar, z
234 |
--------------------------------------------------------------------------------
/train_update.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | from sklearn.mixture import GaussianMixture
7 | from torch.utils.data import DataLoader
8 |
9 | from config import args
10 | from mst_oatd_trainer import train_mst_oatd, MyDataset, seed_torch, collate_fn
11 | from train_labels import Linear_Model
12 |
13 |
14 | def get_z(trajs):
15 | data = MyDataset(trajs)
16 | loader = DataLoader(dataset=data, batch_size=args.batch_size, shuffle=False,
17 | collate_fn=collate_fn, num_workers=4)
18 | MST_OATD_U.train_loader = loader
19 | return MST_OATD_U.get_hidden().cpu()
20 |
21 |
22 | def load_gmm():
23 | checkpoint = torch.load('./models/gmm_update_{}.pt'.format(args.dataset))
24 | gmm = GaussianMixture(n_components=args.n_cluster, covariance_type='diag')
25 | gmm.weights_ = checkpoint['gmm_update_weights']
26 | gmm.means_ = checkpoint['gmm_update_means']
27 | gmm.covariances_ = checkpoint['gmm_update_covariances']
28 | gmm.precisions_cholesky_ = checkpoint['gmm_update_precisions_cholesky']
29 | return gmm
30 |
31 |
32 | def get_index(trajs, cats_sample):
33 | z = get_z(trajs)
34 | cats = gmm.predict(z)
35 | index = [cat in cats_sample for cat in cats]
36 | trajs = trajs[index]
37 | return index, z[index], trajs
38 |
39 |
40 | def get_score(z):
41 | probs = gmm.predict_proba(z)
42 |
43 | idx = []
44 | linear = Linear_Model()
45 | for label in range(args.n_cluster):
46 | data = -probs[:, label]
47 | rank = linear.test(label, torch.Tensor(data).to(args.device))
48 | idx.append(rank)
49 | idx = np.array(idx).T
50 | idxs = np.argsort(idx, axis=1)
51 |
52 | return idxs
53 |
54 |
55 | def update_data(origin_trajs, train_trajs, cats_sample):
56 | _, z, train_trajs = get_index(train_trajs, cats_sample)
57 | idxs = get_score(z)
58 |
59 | max_idxs = idxs[:, 0]
60 | for i, traj in enumerate(train_trajs):
61 | max_idx = max_idxs[i]
62 | origin_trajs[max_idx].append(traj)
63 |
64 | min_c = args.n_cluster - 1
65 | min_idx = idxs[:, min_c][i]
66 |
67 | while not origin_trajs[min_idx]:
68 | min_c -= 1
69 | min_idx = idxs[:, min_c][i]
70 | origin_trajs[min_idx].pop(0)
71 |
72 | return np.array(sum(origin_trajs, []), dtype=object)
73 |
74 |
75 | def train_update(train_trajs, test_trajs, labels, i):
76 | train_data = MyDataset(train_trajs)
77 | test_data = MyDataset(test_trajs)
78 |
79 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True,
80 | collate_fn=collate_fn, num_workers=8)
81 | outliers_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False,
82 | collate_fn=collate_fn, num_workers=8)
83 |
84 | MST_OATD.train_loader = train_loader
85 | MST_OATD.outliers_loader = outliers_loader
86 | MST_OATD.labels = labels
87 |
88 | pr_auc = []
89 | for epoch in range(args.epochs):
90 | MST_OATD.train(epoch)
91 | results = MST_OATD.detection()
92 | pr_auc.append(results)
93 | results = "%.4f" % max(pr_auc)
94 | print("File {} PR_AUC:".format(i), results)
95 | return max(pr_auc)
96 |
97 |
98 | def test_update(test_trajs, labels, i):
99 | test_data = MyDataset(test_trajs)
100 | outliers_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False,
101 | collate_fn=collate_fn, num_workers=8)
102 |
103 | MST_OATD_U.outliers_loader = outliers_loader
104 | MST_OATD_U.labels = labels
105 |
106 | pr_auc = MST_OATD_U.detection()
107 | results = "%.4f" % pr_auc
108 | print("File {} PR_AUC:".format(i), results)
109 | return pr_auc
110 |
111 |
112 | # Get the type of trajectory and save the trajectory data in the n_cluster list.
113 | def get_category(trajs):
114 | z = get_z(trajs)
115 | c_labels = gmm.predict(z)
116 |
117 | origin_trajs = []
118 | for label in range(args.n_cluster):
119 | index = c_labels == label
120 | origin_trajs.append(trajs[index].tolist())
121 | return origin_trajs
122 |
123 |
124 | def main():
125 | random.seed(1234)
126 | train_trajs = np.load('./data/{}/train_data_init.npy'.format(args.dataset),
127 | allow_pickle=True)[-args.train_num:]
128 |
129 | if args.update_mode == 'rank':
130 | random_train_trajs = train_trajs
131 |
132 | all_pr_auc = []
133 |
134 | if args.dataset == 'porto':
135 | for i in range(1, 11):
136 | train_trajs_new = np.load('./data/{}/train_data_{}.npy'.format(args.dataset, i),
137 | allow_pickle=True)
138 | test_trajs = np.load(
139 | './data/{}/outliers_data_{}_{}_{}_{}.npy'.format(args.dataset, i, args.distance, args.fraction,
140 | args.obeserved_ratio),
141 | allow_pickle=True)
142 | outliers_idx = np.load(
143 | "./data/{}/outliers_idx_{}_{}_{}_{}.npy".format(args.dataset, i, args.distance, args.fraction,
144 | args.obeserved_ratio),
145 | allow_pickle=True)
146 |
147 | labels = np.zeros(len(test_trajs))
148 | for idx in outliers_idx:
149 | labels[idx] = 1
150 |
151 | cats = list(range(0, args.n_cluster))
152 | cats_sample = random.sample(cats, args.n_cluster // 4)
153 | test_index, _, _ = get_index(test_trajs, cats_sample)
154 |
155 | if args.update_mode == 'temporal':
156 | train_index, _, _ = get_index(train_trajs_new, cats_sample)
157 | train_trajs = np.concatenate((train_trajs, train_trajs_new[train_index]))[-len(train_trajs):]
158 | pr_auc = train_update(train_trajs, test_trajs[test_index], labels[test_index], i)
159 |
160 | elif args.update_mode == 'rank':
161 | trajs = get_category(random_train_trajs)
162 | train_trajs = update_data(trajs, train_trajs_new, cats_sample)
163 | random_train_trajs = np.concatenate((random_train_trajs, train_trajs_new))[-len(train_trajs):]
164 | pr_auc = train_update(train_trajs, test_trajs[test_index], labels[test_index], i)
165 |
166 | elif args.update_mode == 'pretrain':
167 | pr_auc = test_update(test_trajs[test_index], labels[test_index], i)
168 |
169 | all_pr_auc.append(pr_auc)
170 |
171 | if args.dataset == 'cd':
172 | traj_path = "../datasets/chengdu"
173 | path_list = os.listdir(traj_path)
174 | path_list.sort(key=lambda x: x.split('.'))
175 | path_list = path_list[3: 10]
176 |
177 | for i in range(len(path_list)):
178 | train_trajs_new = np.load('./data/{}/train_data_{}.npy'.format(args.dataset, path_list[i][:8]),
179 | allow_pickle=True)
180 | print(len(train_trajs_new))
181 | test_trajs = np.load(
182 | './data/{}/outliers_data_{}_{}_{}_{}.npy'.format(args.dataset, path_list[i][:8], args.distance,
183 | args.fraction, args.obeserved_ratio),
184 | allow_pickle=True)
185 | outliers_idx = np.load(
186 | "./data/{}/outliers_idx_{}_{}_{}_{}.npy".format(args.dataset, path_list[i][:8], args.distance,
187 | args.fraction, args.obeserved_ratio),
188 | allow_pickle=True)
189 |
190 | labels = np.zeros(len(test_trajs))
191 | for idx in outliers_idx:
192 | labels[idx] = 1
193 |
194 | cats = list(range(0, args.n_cluster))
195 | cats_sample = random.sample(cats, args.n_cluster // 4)
196 | test_index, _, _ = get_index(test_trajs, cats_sample)
197 |
198 | test_trajs = test_trajs[test_index]
199 | labels = labels[test_index]
200 |
201 | if args.update_mode == 'temporal':
202 |
203 | train_index, _, _ = get_index(train_trajs_new, cats_sample)
204 | train_trajs_new = train_trajs_new[train_index]
205 |
206 | train_trajs = train_trajs[-(len(train_trajs) - len(train_trajs_new)):]
207 | train_trajs = np.concatenate((train_trajs, train_trajs_new))
208 |                 print('Trajectory num:', len(train_trajs))
209 | pr_auc = train_update(train_trajs, test_trajs, labels, i)
210 |
211 | elif args.update_mode == 'rank':
212 |
213 | trajs = get_category(random_train_trajs)
214 | train_trajs = update_data(trajs, train_trajs_new, cats_sample)
215 |
216 | random_train_trajs = random_train_trajs[-(len(random_train_trajs) - len(train_trajs_new)):]
217 | random_train_trajs = np.concatenate((random_train_trajs, train_trajs_new))
218 |                 print('Trajectory num:', len(train_trajs))
219 | pr_auc = train_update(train_trajs, test_trajs, labels, i)
220 |
221 | elif args.update_mode == 'pretrain':
222 | pr_auc = test_update(test_trajs, labels, i)
223 |
224 | all_pr_auc.append(pr_auc)
225 | print('------------------------')
226 | results = "%.4f" % (sum(all_pr_auc) / len(all_pr_auc))
227 | print('Average PR_AUC:', results)
228 | print('------------------------')
229 |
230 |
231 | if __name__ == "__main__":
232 | seed_torch(1234)
233 | print("===========================")
234 | print("Dataset:", args.dataset)
235 | print("Mode:", args.update_mode)
236 |
237 | if args.dataset == 'porto':
238 | s_token_size = 51 * 119
239 | t_token_size = 5760
240 | elif args.dataset == 'cd':
241 | s_token_size = 167 * 154
242 | t_token_size = 8640
243 |
244 | gmm = load_gmm()
245 |
246 | MST_OATD = train_mst_oatd(s_token_size, t_token_size, None, None, None, args)
247 |
248 | MST_OATD.mode = 'update'
249 | checkpoint = torch.load(MST_OATD.path_checkpoint)
250 | MST_OATD.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s'])
251 | MST_OATD.MST_OATD_T.load_state_dict(checkpoint['model_state_dict_t'])
252 |
253 | MST_OATD_U = train_mst_oatd(s_token_size, t_token_size, None, None, None, args)
254 |
255 | checkpoint_U = torch.load(MST_OATD_U.path_checkpoint)
256 | MST_OATD_U.mode = 'update'
257 | MST_OATD_U.MST_OATD_S.load_state_dict(checkpoint_U['model_state_dict_s'])
258 | MST_OATD_U.MST_OATD_T.load_state_dict(checkpoint_U['model_state_dict_t'])
259 | main()
260 |
--------------------------------------------------------------------------------
/mst_oatd_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from sklearn.mixture import GaussianMixture
10 | from torch.optim.lr_scheduler import StepLR
11 | from torch.utils.data import Dataset
12 |
13 | from logging_set import get_logger
14 | from mst_oatd import MST_OATD
15 | from utils import auc_score, make_mask, make_len_mask
16 |
17 |
18 | def collate_fn(batch):
19 | max_len = max(len(x) for x in batch)
20 | seq_lengths = list(map(len, batch))
21 | batch_trajs = [x + [[0, [0] * 6]] * (max_len - len(x)) for x in batch]
22 | return torch.LongTensor(np.array(batch_trajs, dtype=object)[:, :, 0].tolist()), \
23 | torch.Tensor(np.array(batch_trajs, dtype=object)[:, :, 1].tolist()), np.array(seq_lengths)
24 |
25 |
26 | def seed_torch(seed):
27 | random.seed(seed)
28 | os.environ['PYTHONHASHSEED'] = str(seed)
29 | np.random.seed(seed)
30 | torch.manual_seed(seed)
31 | torch.cuda.manual_seed(seed)
32 | torch.backends.cudnn.benchmark = False
33 | torch.backends.cudnn.deterministic = True
34 |
35 |
36 | class MyDataset(Dataset):
37 | def __init__(self, seqs):
38 | self.seqs = seqs
39 |
40 | def __len__(self):
41 | return len(self.seqs)
42 |
43 | def __getitem__(self, index):
44 | data_seqs = self.seqs[index]
45 | return data_seqs
46 |
47 |
48 | def time_convert(times, time_interval):
49 | return torch.Tensor((times[:, :, 2] + times[:, :, 1] * 60 + times[:, :, 0] * 3600) // time_interval).long()
50 |
51 |
52 | def savecheckpoint(state, file_name):
53 | torch.save(state, file_name)
54 |
55 |
56 | class train_mst_oatd:
57 | def __init__(self, s_token_size, t_token_size, labels, train_loader, outliers_loader, args):
58 |
59 | self.MST_OATD_S = MST_OATD(s_token_size, s_token_size, args).to(args.device)
60 | self.MST_OATD_T = MST_OATD(s_token_size, t_token_size, args).to(args.device)
61 |
62 | self.device = args.device
63 | self.dataset = args.dataset
64 | self.n_cluster = args.n_cluster
65 | self.hidden_size = args.hidden_size
66 |
67 | self.crit = nn.CrossEntropyLoss()
68 | self.detec = nn.CrossEntropyLoss(reduction='none')
69 |
70 | self.pretrain_optimizer_s = optim.AdamW([
71 | {'params': self.MST_OATD_S.parameters()},
72 | ], lr=args.pretrain_lr_s)
73 |
74 | self.pretrain_optimizer_t = optim.AdamW([
75 | {'params': self.MST_OATD_T.parameters()},
76 | ], lr=args.pretrain_lr_t)
77 |
78 | self.optimizer_s = optim.AdamW([
79 | {'params': self.MST_OATD_S.parameters()},
80 | ], lr=args.lr_s)
81 |
82 | self.optimizer_t = optim.Adam([
83 | {'params': self.MST_OATD_T.parameters()},
84 | ], lr=args.lr_t)
85 |
86 | self.lr_pretrain_s = StepLR(self.pretrain_optimizer_s, step_size=2, gamma=0.9)
87 | self.lr_pretrain_t = StepLR(self.pretrain_optimizer_t, step_size=2, gamma=0.9)
88 |
89 | self.train_loader = train_loader
90 | self.outliers_loader = outliers_loader
91 |
92 | self.pretrained_path = 'models/pretrain_mstoatd_{}.pth'.format(args.dataset)
93 | self.path_checkpoint = 'models/mstoatd_{}.pth'.format(args.dataset)
94 | self.gmm_path = "models/gmm_{}.pt".format(args.dataset)
95 | self.gmm_update_path = "models/gmm_update_{}.pt".format(args.dataset)
96 | self.logger = get_logger("./logs/{}.log".format(args.dataset))
97 |
98 | self.labels = labels
99 | if args.dataset == 'cd':
100 | self.time_interval = 10
101 | else:
102 | self.time_interval = 15
103 | self.mode = 'train'
104 |
105 | self.s1_size = args.s1_size
106 | self.s2_size = args.s2_size
107 |
108 | def pretrain(self, epoch):
109 | self.MST_OATD_S.train()
110 | self.MST_OATD_T.train()
111 | epo_loss = 0
112 |
113 | for batch in self.train_loader:
114 | trajs, times, seq_lengths = batch
115 | batch_size = len(trajs)
116 |
117 | mask = make_mask(make_len_mask(trajs)).to(self.device)
118 |
119 | self.pretrain_optimizer_s.zero_grad()
120 | self.pretrain_optimizer_t.zero_grad()
121 | output_s, _, _, _ = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "pretrain", -1)
122 | output_t, _, _, _ = self.MST_OATD_T(trajs, times, seq_lengths, batch_size, "pretrain", -1)
123 |
124 | times = time_convert(times, self.time_interval)
125 |
126 | loss = self.crit(output_s[mask == 1], trajs.to(self.device)[mask == 1])
127 | loss += self.crit(output_t[mask == 1], times.to(self.device)[mask == 1])
128 |
129 | loss.backward()
130 |
131 | self.pretrain_optimizer_s.step()
132 | self.pretrain_optimizer_t.step()
133 | epo_loss += loss.item()
134 |
135 | self.lr_pretrain_s.step()
136 | self.lr_pretrain_t.step()
137 | epo_loss = "%.4f" % (epo_loss / len(self.train_loader))
138 | self.logger.info("Epoch {} pretrain loss: {}".format(epoch + 1, epo_loss))
139 | checkpoint = {"model_state_dict_s": self.MST_OATD_S.state_dict(),
140 | "model_state_dict_t": self.MST_OATD_T.state_dict()}
141 | torch.save(checkpoint, self.pretrained_path)
142 |
143 | def get_hidden(self):
144 | checkpoint = torch.load(self.path_checkpoint)
145 | self.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s'])
146 | self.MST_OATD_S.eval()
147 | with torch.no_grad():
148 | z = []
149 | for batch in self.train_loader:
150 | trajs, times, seq_lengths = batch
151 | batch_size = len(trajs)
152 | _, _, _, hidden = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "pretrain", -1)
153 | z.append(hidden.squeeze(0))
154 | z = torch.cat(z, dim=0)
155 | return z
156 |
157 | def train_gmm(self):
158 | self.MST_OATD_S.eval()
159 | self.MST_OATD_T.eval()
160 | checkpoint = torch.load(self.pretrained_path)
161 | self.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s'])
162 | self.MST_OATD_T.load_state_dict(checkpoint['model_state_dict_t'])
163 |
164 | with torch.no_grad():
165 | z_s = []
166 | z_t = []
167 | for batch in self.train_loader:
168 | trajs, times, seq_lengths = batch
169 | batch_size = len(trajs)
170 | _, _, _, hidden_s = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "pretrain", -1)
171 | _, _, _, hidden_t = self.MST_OATD_T(trajs, times, seq_lengths, batch_size, "pretrain", -1)
172 |
173 | z_s.append(hidden_s.squeeze(0))
174 | z_t.append(hidden_t.squeeze(0))
175 | z_s = torch.cat(z_s, dim=0)
176 | z_t = torch.cat(z_t, dim=0)
177 |
178 | self.logger.info('Start fitting gaussian mixture model!')
179 |
180 | self.gmm_s = GaussianMixture(n_components=self.n_cluster, covariance_type="diag", n_init=1)
181 | self.gmm_s.fit(z_s.cpu().numpy())
182 |
183 | self.gmm_t = GaussianMixture(n_components=self.n_cluster, covariance_type="diag", n_init=1)
184 | self.gmm_t.fit(z_t.cpu().numpy())
185 |
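# Persist the fitted sklearn GMM parameters (mixture weights, means, diagonal covariances)
# for both branches; load_mst_oatd reads this file to initialize the model's learnable priors.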
186 | def save_weights_for_MSTOATD(self):
187 | savecheckpoint({"gmm_s_mu_prior": self.gmm_s.means_,
188 | "gmm_s_pi_prior": self.gmm_s.weights_,
189 | "gmm_s_logvar_prior": self.gmm_s.covariances_,
190 | "gmm_t_mu_prior": self.gmm_t.means_,
191 | "gmm_t_pi_prior": self.gmm_t.weights_,
192 | "gmms_t_logvar_prior": self.gmm_t.covariances_}, self.gmm_path)
193 |
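# Refit a single GMM (n_init=3) on the latent space of the trained spatial branch and save its
# full parameter set, including precisions_cholesky_, so that get_prob can rebuild it later.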
194 | def train_gmm_update(self):
195 |
196 | checkpoint = torch.load(self.path_checkpoint)
197 | self.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s'])
198 | self.MST_OATD_S.eval()
199 |
200 | with torch.no_grad():
201 | z = []
202 | for batch in self.train_loader:
203 | trajs, times, seq_lengths = batch
204 | batch_size = len(trajs)
205 | _, _, _, hidden = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "pretrain", -1)
206 | z.append(hidden.squeeze(0))
207 | z = torch.cat(z, dim=0)
208 |
209 | self.logger.info('Start fitting Gaussian mixture model')
210 |
211 | self.gmm = GaussianMixture(n_components=self.n_cluster, covariance_type="diag", n_init=3)
212 | self.gmm.fit(z.cpu().numpy())
213 |
214 | savecheckpoint({"gmm_update_weights": self.gmm.weights_,
215 | "gmm_update_means": self.gmm.means_,
216 | "gmm_update_covariances": self.gmm.covariances_,
217 | "gmm_update_precisions_cholesky": self.gmm.precisions_cholesky_}, self.gmm_update_path)
218 |
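# Main training epoch: both branches are optimized with the full objective from self.Loss
# (reconstruction + Gaussian prior + categorical terms); while self.mode == "train", a joint
# checkpoint is saved to self.path_checkpoint at the end of every epoch.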
219 | def train(self, epoch):
220 | self.MST_OATD_S.train()
221 | self.MST_OATD_T.train()
222 | total_loss = 0
223 | for batch in self.train_loader:
224 | trajs, times, seq_lengths = batch
225 | batch_size = len(trajs)
226 |
227 | mask = make_mask(make_len_mask(trajs)).to(self.device)
228 |
229 | self.optimizer_s.zero_grad()
230 | self.optimizer_t.zero_grad()
231 |
232 | x_hat_s, mu_s, log_var_s, z_s = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "train", -1)
233 | loss = self.Loss(x_hat_s, trajs.to(self.device), mu_s.squeeze(0), log_var_s.squeeze(0),
234 | z_s.squeeze(0), 's', mask)
235 | x_hat_t, mu_t, log_var_t, z_t = self.MST_OATD_T(trajs, times, seq_lengths, batch_size, "train", -1)
236 | times = time_convert(times, self.time_interval)
237 | loss += self.Loss(x_hat_t, times.to(self.device), mu_t.squeeze(0), log_var_t.squeeze(0),
238 | z_t.squeeze(0), 't', mask)
239 |
240 | loss.backward()
241 | self.optimizer_s.step()
242 | self.optimizer_t.step()
243 | total_loss += loss.item()
244 |
245 | if self.mode == "train":
246 | total_loss = "%.4f" % (total_loss / len(self.train_loader))
247 | self.logger.info('Epoch {} loss: {}'.format(epoch + 1, total_loss))
248 | checkpoint = {"model_state_dict_s": self.MST_OATD_S.state_dict(),
249 | "model_state_dict_t": self.MST_OATD_T.state_dict()}
250 | torch.save(checkpoint, self.path_checkpoint)
251 |
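# Outlier detection: for every test batch and every mixture component c, decode the trajectory
# under component c, turn the masked reconstruction error into a per-trajectory likelihood, and
# keep the maximum likelihood over components for each branch. The anomaly score is
# 1 - likelihood_s * likelihood_t, evaluated against self.labels with PR-AUC (average precision).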
252 | def detection(self):
253 |
254 | self.MST_OATD_S.eval()
255 | self.MST_OATD_T.eval()
256 | all_likelihood_s = []
257 | all_likelihood_t = []
258 |
259 | with torch.no_grad():
260 |
261 | for batch in self.outliers_loader:
262 | trajs, times, seq_lengths = batch
263 | batch_size = len(trajs)
264 | mask = make_mask(make_len_mask(trajs)).to(self.device)
265 | times_token = time_convert(times, self.time_interval)
266 |
267 | c_likelihood_s = []
268 | c_likelihood_t = []
269 |
270 | for c in range(self.n_cluster):
271 | output_s, _, _, _ = self.MST_OATD_S(trajs, times, seq_lengths, batch_size, "test", c)
272 | likelihood_s = - self.detec(output_s.reshape(-1, output_s.shape[-1]),
273 | trajs.to(self.device).reshape(-1))
274 | likelihood_s = torch.exp(
275 | torch.sum(mask * (likelihood_s.reshape(batch_size, -1)), dim=-1) / torch.sum(mask, 1))
276 |
277 | output_t, _, _, _ = self.MST_OATD_T(trajs, times, seq_lengths, batch_size, "test", c)
278 | likelihood_t = - self.detec(output_t.reshape(-1, output_t.shape[-1]),
279 | times_token.to(self.device).reshape(-1))
280 | likelihood_t = torch.exp(
281 | torch.sum(mask * (likelihood_t.reshape(batch_size, -1)), dim=-1) / torch.sum(mask, 1))
282 |
283 | c_likelihood_s.append(likelihood_s.unsqueeze(0))
284 | c_likelihood_t.append(likelihood_t.unsqueeze(0))
285 |
286 | all_likelihood_s.append(torch.cat(c_likelihood_s).max(0)[0])
287 | all_likelihood_t.append(torch.cat(c_likelihood_t).max(0)[0])
288 |
289 | likelihood_s = torch.cat(all_likelihood_s, dim=0)
290 | likelihood_t = torch.cat(all_likelihood_t, dim=0)
291 |
292 | pr_auc = auc_score(self.labels, (1 - likelihood_s * likelihood_t).cpu().detach().numpy())
293 | return pr_auc
294 |
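# Log-density of a diagonal Gaussian, summed over the latent dimensions; gaussian_pdfs_log
# stacks this quantity for every mixture component into an (N, n_cluster) matrix.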
295 | def gaussian_pdf_log(self, x, mu, log_var):
296 | return -0.5 * (torch.sum(np.log(np.pi * 2) + log_var + (x - mu).pow(2) / torch.exp(log_var), 1))
297 |
298 | def gaussian_pdfs_log(self, x, mus, log_vars):
299 | G = []
300 | for c in range(self.n_cluster):
301 | G.append(self.gaussian_pdf_log(x, mus[c:c + 1, :], log_vars[c:c + 1, :]).view(-1, 1))
302 | return torch.cat(G, 1)
303 |
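# VAE-with-GMM-prior objective, computed per branch ('s' spatial / 't' temporal):
#   reconstruction term - self.crit on the unpadded positions,
#   Gaussian term       - log-density of z under the posterior minus the per-component prior
#                         log-densities, scaled by 1 / self.hidden_size,
#   categorical term    - KL between the soft cluster assignment derived from z and the
#                         learnable prior pi, weighted by 0.1.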
304 | def Loss(self, x_hat, targets, z_mu, z_sigma2_log, z, mode, mask):
305 | if mode == 's':
306 | pi = self.MST_OATD_S.pi_prior
307 | log_sigma2_c = self.MST_OATD_S.log_var_prior
308 | mu_c = self.MST_OATD_S.mu_prior
309 | elif mode == 't':
310 | pi = self.MST_OATD_T.pi_prior
311 | log_sigma2_c = self.MST_OATD_T.log_var_prior
312 | mu_c = self.MST_OATD_T.mu_prior
313 |
314 | reconstruction_loss = self.crit(x_hat[mask == 1], targets[mask == 1])
315 |
316 | gaussian_loss = torch.mean(torch.mean(self.gaussian_pdf_log(z, z_mu, z_sigma2_log).unsqueeze(1) -
317 | self.gaussian_pdfs_log(z, mu_c, log_sigma2_c), dim=1), dim=-1).mean()
318 |
319 | pi = F.softmax(pi, dim=-1)
320 | z = z.unsqueeze(1)
321 | mu_c = mu_c.unsqueeze(0)
322 | log_sigma2_c = log_sigma2_c.unsqueeze(0)
323 |
324 | logits = - torch.sum(torch.pow(z - mu_c, 2) / torch.exp(log_sigma2_c), dim=-1)
325 | logits = F.softmax(logits, dim=-1) + 1e-10
326 | category_loss = torch.mean(torch.sum(logits * (torch.log(logits) - torch.log(pi).unsqueeze(0)), dim=-1))
327 |
328 | loss = reconstruction_loss + gaussian_loss / self.hidden_size + category_loss * 0.1
329 | return loss
330 |
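# Restore the pretrained branch weights and initialize their learnable GMM priors
# (pi_prior, mu_prior, log_var_prior) from the sklearn parameters saved by save_weights_for_MSTOATD.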
331 | def load_mst_oatd(self):
332 | checkpoint = torch.load(self.pretrained_path)
333 | self.MST_OATD_S.load_state_dict(checkpoint['model_state_dict_s'])
334 | self.MST_OATD_T.load_state_dict(checkpoint['model_state_dict_t'])
335 |
336 | gmm_params = torch.load(self.gmm_path)
337 |
338 | self.MST_OATD_S.pi_prior.data = torch.from_numpy(gmm_params['gmm_s_pi_prior']).to(self.device)
339 | self.MST_OATD_S.mu_prior.data = torch.from_numpy(gmm_params['gmm_s_mu_prior']).to(self.device)
340 | self.MST_OATD_S.log_var_prior.data = torch.from_numpy(gmm_params['gmm_s_logvar_prior']).to(self.device)
341 |
342 | self.MST_OATD_T.pi_prior.data = torch.from_numpy(gmm_params['gmm_t_pi_prior']).to(self.device)
343 | self.MST_OATD_T.mu_prior.data = torch.from_numpy(gmm_params['gmm_t_mu_prior']).to(self.device)
344 | self.MST_OATD_T.log_var_prior.data = torch.from_numpy(gmm_params['gmm_t_logvar_prior']).to(self.device)
345 |
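# Score latent vectors z against the updated GMM: rebuild a sklearn GaussianMixture from the
# parameters saved by train_gmm_update, compute per-component posterior probabilities, and
# dump the sorted (negated) probabilities for each component to probs/probs_{label}_{dataset}.npy.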
346 | def get_prob(self, z):
347 | gmm = GaussianMixture(n_components=self.n_cluster, covariance_type='diag')
348 | gmm_params = torch.load(self.gmm_update_path)
349 | gmm.precisions_cholesky_ = gmm_params['gmm_update_precisions_cholesky']
350 | gmm.weights_ = gmm_params['gmm_update_weights']
351 | gmm.means_ = gmm_params['gmm_update_means']
352 | gmm.covariances_ = gmm_params['gmm_update_covariances']
353 |
354 | probs = gmm.predict_proba(z)
355 |
356 | for label in range(self.n_cluster):
357 | np.save('probs/probs_{}_{}.npy'.format(label, self.dataset), np.sort(-probs[:, label]))
358 |
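# Typical call order for this trainer, as implied by the file dependencies above (a sketch only;
# the driver scripts in this repo may sequence these calls differently):
#   1. pretrain(epoch)                          -> writes self.pretrained_path
#   2. train_gmm(); save_weights_for_MSTOATD()  -> writes self.gmm_path
#   3. load_mst_oatd(); train(epoch)            -> writes self.path_checkpoint
#   4. detection()                              -> returns PR-AUC on self.outliers_loader
#   5. update stage: train_gmm_update(), then get_hidden() and get_prob(z)
#                                               -> writes self.gmm_update_path and probs/*.npy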
--------------------------------------------------------------------------------