├── .gitignore ├── LICENSE ├── README.md ├── adone_experiment.py ├── anomaly_insert.py ├── anomalydae_experiment.py ├── data └── .gitignore ├── data_finefoods.py ├── data_finefoods_small.py ├── data_movies.py ├── data_movies_small.py ├── data_reddit.py ├── data_wikipedia.py ├── dominant_experiment.py ├── extract_movies.py ├── isoforest_experiment.py ├── models ├── conv.py ├── conv_sample.py ├── data.py ├── loss.py ├── net.py ├── net_sample.py ├── sampler.py └── score.py ├── requirements.txt ├── results └── .gitignore ├── storage └── .gitignore ├── train_full_experiment.py ├── train_sample_experiment.py └── utils ├── seed.py ├── sparse_combine.py ├── sprand.py └── sum_dict.py /.gitignore: -------------------------------------------------------------------------------- 1 | # OS generated files # 2 | ###################### 3 | .DS_Store 4 | .DS_Store? 5 | 6 | # python generated files # 7 | ########################## 8 | *.pyc 9 | */venv/ 10 | 11 | # ide generated directories 12 | .cache 13 | .vscode 14 | 15 | # pytest generated files # 16 | ########################## 17 | .pytest_cache 18 | 19 | # jupyter generated files # 20 | ########################## 21 | .ipynb_checkpoints/ 22 | 23 | # emacs autosave files 24 | \#* 25 | *~ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interaction-Focused Anomaly Detection on Bipartite Node-and-Edge-Attributed Graphs 2 | 3 | This repository contains the experimental source code of the [*Interaction-Focused Anomaly Detection on Bipartite Node-and-Edge-Attributed Graphs*](https://engineering.grab.com/graph-anomaly-model) paper presented at the [International Joint Conference on Neural Networks (IJCNN) 2023](https://2023.ijcnn.org/). 4 | 5 | Authors: [Rizal Fathony](mailto:rizal.fathony@grab.com), [Jenn Ng](mailto:jenn.ng@grab.com), and [Jia Chen](mailto:jia.chen@grab.com). 6 | 7 | ## Abstract 8 | 9 | Many anomaly detection applications naturally produce datasets that can be represented as bipartite graphs (user–interaction–item graphs). 
These graph datasets are usually supplied with rich information on both the entities (nodes) and the interactions (edges). Unfortunately, previous graph neural network anomaly models are unable to fully capture the rich information and produce high-performing detections on these graphs, as they mostly focus on homogeneous graphs and node attributes only. To overcome the problem, we propose a new graph anomaly detection model that focuses on the rich interactions in bipartite graphs. Specifically, our model takes a bipartite node-and-edge-attributed graph and produces anomaly scores for each of its edges and then for each of its bipartite nodes. We design our model as an autoencoder-type model with a customized encoder and decoder to facilitate the compression of node features, edge features, and graph structure into node-level latent representations. The reconstruction errors of each edge and node are then leveraged to spot the anomalies. Our network architecture is scalable, enabling large real-world applications. Finally, we demonstrate that our method significantly outperforms previous anomaly detection methods in the experiments. 10 | 11 | ## Setup 12 | 13 | 1. Install the required packages using: 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 2. Download the datasets. 18 | 19 | - `wikipedia` and `reddit`: 20 | ``` 21 | wget -P data/ http://snap.stanford.edu/jodie/wikipedia.csv 22 | wget -P data/ http://snap.stanford.edu/jodie/reddit.csv 23 | ``` 24 | 25 | - `finefoods`: Download from [here](https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews?select=Reviews.csv) to the `data` folder. Rename the file to `finefoods.csv`. 26 | 27 | - `movies`: Download from [here](https://snap.stanford.edu/data/web-Movies.html). Extract and rename the file to `movies.txt`. Generate the `.csv` file by running `python extract_movies.py`. 28 | 29 | 30 | ## Construct Graph Datasets 31 | 32 | We construct the graph datasets by loading the csv files, building PyG graph data, and then injecting anomalies into each dataset. For each dataset, please run: 33 | - `wikipedia` dataset: `python data_wikipedia.py` 34 | - `reddit` dataset: `python data_reddit.py` 35 | 36 | - `finefoods-large` dataset: `python data_finefoods.py` 37 | - `finefoods-small` dataset: `python data_finefoods_small.py` 38 | - `movies-large` dataset: `python data_movies.py` 39 | - `movies-small` dataset: `python data_movies_small.py` 40 | 41 | `Note`: for `finefoods` and `movies`, we use `sentence-transformers` to generate features from the review text, so running the graph construction on a machine with GPU support is recommended. The `finefoods` and `movies` datasets are also quite large; a machine with a large amount of memory (60GB or 120GB) is therefore required. 42 | 43 | Each script converts the csv files into PyG graph format and constructs 10 different copies of the graph by injecting random anomalies via `anomaly_insert.py`. Each graph instance will have a different set of anomalies. 44 | 45 | ## Run Experiment 46 | 47 | To run the experiments, please execute the corresponding file for each model. 48 | 49 | 1. `GraphBEAN`: 50 | ``` 51 | python train_full_experiment.py --name wikipedia_anomaly --id 0 52 | ``` 53 | 54 | 1. `GraphBEAN` with neighborhood sampling: 55 | ``` 56 | python train_sample_experiment.py --name wikipedia_anomaly --id 0 --batch-size 128 57 | ``` 58 | 59 | 1. `IsolationForest`: 60 | ``` 61 | python isoforest_experiment.py --name wikipedia_anomaly --id 0 62 | ``` 63 | 64 | 1. 
`DOMINANT`: 65 | ``` 66 | python dominant_experiment.py --name wikipedia_anomaly --id 0 67 | ``` 68 | 69 | 1. `AnomalyDAE`: 70 | ``` 71 | python anomalydae_experiment.py --name wikipedia_anomaly --id 0 72 | ``` 73 | 74 | 1. `AdONE`: 75 | ``` 76 | python adone_experiment.py --name wikipedia_anomaly --id 0 77 | ``` 78 | 79 | The argument `--name` indicates which dataset we want the model to run on, in the format `{dataset_name}_anomaly`. Additional arguments are available depending on the model. 80 | 81 | - Arguments for **all** models. 82 | ``` 83 | --name : dataset name 84 | --id : which instance of the anomaly-injected graph [0-9] 85 | ``` 86 | - Arguments for `DOMINANT`, `AnomalyDAE`, `AdONE`, and `GraphBEAN`. 87 | ``` 88 | --n-epoch : number of epochs in the training [default: 50] 89 | --lr : learning rate [default: 1e-2] 90 | ``` 91 | - Arguments for `DOMINANT` and `AnomalyDAE`. 92 | ``` 93 | --alpha : balance parameter [default: 0.8] 94 | ``` 95 | - Arguments for `GraphBEAN` (full and sample training). 96 | ``` 97 | --eta : structure decoder loss weight [default: 0.2] 98 | --score-agg : aggregation method for node anomaly score 99 | (max or mean) [default: max] 100 | --scheduler-milestones : milestones for the learning rate scheduler [default: []] 101 | ``` 102 | - Arguments for `GraphBEAN` (sample training). 103 | ``` 104 | --batch-size : number of target nodes in one batch [default: 2048] 105 | --num-neighbors-u : number of neighbors sampled for node u [default: 10] 106 | --num-neighbors-v : number of neighbors sampled for node v [default: 10] 107 | --num-workers : number of workers in the dataloader [default: 0] 108 | suggestion: set it to the number of available cores 109 | ``` 110 | 111 | Running the experiments on a machine with GPU support is recommended for all models except IsolationForest. 112 | 113 | ## License 114 | 115 | This repository is licensed under the [MIT License](LICENSE). 116 | 117 | ## Citation 118 | 119 | If you use this repository for academic purposes, please cite the following paper: 120 | 121 | 122 | > R. Fathony, J. Ng and J. Chen, "Interaction-Focused Anomaly Detection on Bipartite Node-and-Edge-Attributed Graphs," 2023 International Joint Conference on Neural Networks (IJCNN), Gold Coast, Australia, 2023, pp. 1-10, doi: 10.1109/IJCNN54540.2023.10191331. 123 | -------------------------------------------------------------------------------- /adone_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
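#
# Baseline experiment: AdONE from pygod on an anomaly-injected bipartite graph.
# The script loads one injected graph instance, flattens the bipartite
# node-and-edge-attributed graph into a homogeneous PyG `Data` object (edge
# features are aggregated into the two node sets via scatter max/mean), fits
# AdONE, maps the resulting node scores back to the U/V partitions and to the
# edges, and stores AUC metrics and anomaly scores under `results/`.
#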
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import sys 5 | from sklearn.metrics import roc_curve, precision_recall_curve, auc 6 | from data_finefoods import load_graph 7 | 8 | import argparse 9 | import os 10 | 11 | import torch 12 | from torch_geometric.data import Data 13 | from torch_scatter import scatter 14 | 15 | from utils.seed import seed_all 16 | 17 | # train a detector 18 | from pygod.models import AdONE 19 | 20 | # %% args 21 | 22 | parser = argparse.ArgumentParser(description="AdONE") 23 | parser.add_argument("--name", type=str, default="wikipedia_anomaly", help="name") 24 | parser.add_argument( 25 | "--key", type=str, default="graph_anomaly_list", help="key to the data" 26 | ) 27 | parser.add_argument("--id", type=int, default=0, help="id to the data") 28 | parser.add_argument("--n-epoch", type=int, default=200, help="number of epoch") 29 | parser.add_argument( 30 | "--num-neighbors", type=int, default=-1, help="number of neighbors for node" 31 | ) 32 | parser.add_argument("--batch-size", type=int, default=0, help="batch size") 33 | parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") 34 | parser.add_argument("--gpu", type=int, default=0, help="gpu number") 35 | 36 | args1 = vars(parser.parse_args()) 37 | 38 | args2 = { 39 | "seed": 0, 40 | "hidden_channels": 32, 41 | "dropout_prob": 0.0, 42 | } 43 | 44 | args = {**args1, **args2} 45 | 46 | seed_all(args["seed"]) 47 | 48 | result_dir = "results/" 49 | 50 | # %% data 51 | data = load_graph(args["name"], args["key"], args["id"]) 52 | 53 | u_ch = data.xu.shape[1] 54 | v_ch = data.xv.shape[1] 55 | e_ch = data.xe.shape[1] 56 | 57 | print( 58 | f"Data dimension: U node = {data.xu.shape}; V node = {data.xv.shape}; E edge = {data.xe.shape}; \n" 59 | ) 60 | 61 | # %% model 62 | 63 | xu, xv = data.xu, data.xv 64 | xe, adj = data.xe, data.adj 65 | yu, yv, ye = data.yu, data.yv, data.ye 66 | 67 | 68 | # %% to homogen 69 | nu = xu.shape[0] 70 | nv = xv.shape[0] 71 | nn = nu + nv 72 | 73 | # to homogen 74 | row_h = torch.cat([adj.storage.row(), adj.storage.col() + nu]) 75 | col_h = torch.cat([adj.storage.col() + nu, adj.storage.row()]) 76 | edge_index_h = torch.stack([row_h, col_h]) 77 | xuh = torch.cat( 78 | [ 79 | scatter(xe, adj.storage.row(), dim=0, reduce="max"), 80 | scatter(xe, adj.storage.row(), dim=0, reduce="mean"), 81 | ], 82 | dim=1, 83 | ) 84 | xvh = torch.cat( 85 | [ 86 | scatter(xe, adj.storage.col(), dim=0, reduce="max"), 87 | scatter(xe, adj.storage.col(), dim=0, reduce="mean"), 88 | ], 89 | dim=1, 90 | ) 91 | xh = torch.cat([xuh, xvh], dim=0) 92 | yh = torch.cat([yu, yv], dim=0) 93 | data_h = Data(x=xh, edge_index=edge_index_h, y=yh) 94 | 95 | # %% model 96 | 97 | device = torch.device(f'cuda:{args["gpu"]}' if torch.cuda.is_available() else "cpu") 98 | 99 | model = AdONE( 100 | hid_dim=args["hidden_channels"], 101 | dropout=args["dropout_prob"], 102 | epoch=args["n_epoch"], 103 | lr=args["lr"], 104 | verbose=True, 105 | gpu=args["gpu"], 106 | batch_size=args["batch_size"], 107 | num_neigh=args["num_neighbors"], 108 | ) 109 | 110 | print(args) 111 | print() 112 | 113 | 114 | def auc_eval(pred, y): 115 | 116 | rc_curve = roc_curve(y, pred) 117 | pr_curve = precision_recall_curve(y, pred) 118 | roc_auc = auc(rc_curve[0], rc_curve[1]) 119 | pr_auc = auc(pr_curve[1], pr_curve[0]) 120 | 121 | return roc_auc, pr_auc, rc_curve, pr_curve 122 | 123 | 124 | # %% run training 125 | 126 | model.fit(data_h, yh) 127 | score = model.decision_scores_ 
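# `decision_scores_` follows the node ordering of `data_h`: the nu U-nodes
# come first and the nv V-nodes after them, so the homogeneous scores can be
# split back into the two bipartite node sets below; an edge score is then
# approximated by averaging the scores of its two endpoints.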
128 | 129 | score_u = score[:nu] 130 | score_v = score[nu:] 131 | score_e_u = score_u[adj.storage.row().numpy()] 132 | score_e_v = score_v[adj.storage.col().numpy()] 133 | score_e = (score_e_u + score_e_v) / 2 134 | 135 | u_roc_auc, u_pr_auc, u_rc_curve, u_pr_curve = auc_eval(score_u, yu.numpy()) 136 | v_roc_auc, v_pr_auc, v_rc_curve, v_pr_curve = auc_eval(score_v, yv.numpy()) 137 | e_roc_auc, e_pr_auc, e_rc_curve, e_pr_curve = auc_eval(score_e, ye.numpy()) 138 | 139 | print( 140 | f"Eval | " 141 | + f"u auc-roc: {u_roc_auc:.4f}, v auc-roc: {v_roc_auc:.4f}, e auc-roc: {e_roc_auc:.4f} | " 142 | + f"u auc-pr {u_pr_auc:.4f}, v auc-pr {v_pr_auc:.4f}, e auc-pr {e_pr_auc:.4f}" 143 | ) 144 | 145 | auc_metrics = { 146 | "u_roc_auc": u_roc_auc, 147 | "u_pr_auc": u_pr_auc, 148 | "v_roc_auc": v_roc_auc, 149 | "v_pr_auc": v_pr_auc, 150 | "e_roc_auc": e_roc_auc, 151 | "e_pr_auc": e_pr_auc, 152 | "u_roc_curve": u_rc_curve, 153 | "u_pr_curve": u_pr_curve, 154 | "v_roc_curve": v_rc_curve, 155 | "v_pr_curve": v_pr_curve, 156 | "e_roc_curve": e_rc_curve, 157 | "e_pr_curve": e_pr_curve, 158 | } 159 | anomaly_score = {"score_u": score_u, "score_v": score_v, "score_e": score_e} 160 | 161 | model_stored = { 162 | "args": args, 163 | "auc_metrics": auc_metrics, 164 | "state_dict": model.model.state_dict(), 165 | } 166 | output_stored = {"args": args, "anomaly_score": anomaly_score} 167 | 168 | print("Saving current results...") 169 | torch.save( 170 | model_stored, 171 | os.path.join(result_dir, f"adone-{args['name']}-{args['id']}-model.th"), 172 | ) 173 | torch.save( 174 | output_stored, 175 | os.path.join(result_dir, f"adone-{args['name']}-{args['id']}-output.th"), 176 | ) 177 | 178 | 179 | print() 180 | print(args) 181 | -------------------------------------------------------------------------------- /anomaly_insert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
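#
# Anomaly injection utilities used by the data_*.py scripts. Feature anomalies
# are created either by `outside_cofidence_interval` (replacing a fraction of
# feature values with draws from a truncated normal beyond `std_cutoff`
# standard deviations) or by `scaled_gaussian_noise`; structural anomalies are
# created by `dense_block`, which densely connects randomly chosen groups of
# U and V nodes. `inject_random_block_anomaly` mixes both kinds with randomized
# parameters. Typical use, as in data_finefoods_small.py:
#
#   graph_with_anomalies = inject_random_block_anomaly(
#       graph, num_group=20, num_nodes_range=(1, 12))
#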
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | 6 | import numpy as np 7 | from scipy.stats import truncnorm 8 | from torch_sparse import SparseTensor 9 | 10 | from models.data import BipartiteData 11 | 12 | from typing import Tuple, Union 13 | 14 | # %% features outliers 15 | 16 | # features outside confidence interval 17 | def outside_cofidence_interval( 18 | x: torch.Tensor, prop_sample=0.1, prop_feat=0.3, std_cutoff=3.0, mu=None, sigm=None 19 | ): 20 | n, m = x.shape 21 | ns = int(np.ceil(prop_sample * n)) 22 | ms = int(np.ceil(prop_feat * m)) 23 | 24 | # random outlier from truncated normal 25 | left_side = truncnorm.rvs(-np.inf, -std_cutoff, size=ns * ms) 26 | right_side = truncnorm.rvs(std_cutoff, np.inf, size=ns * ms) 27 | lr_flag = np.random.randint(2, size=ns * ms) 28 | random_outliers = lr_flag * left_side + (1 - lr_flag) * right_side 29 | 30 | # determine which sample & features that are randomized 31 | feat_idx = np.random.rand(ns, m).argsort(axis=1)[:, :ms] 32 | sample_idx = np.random.choice(n, ns, replace=False) 33 | row_idx = np.tile(sample_idx[:, None], (1, ms)).flatten() 34 | col_idx = feat_idx.flatten() 35 | 36 | # calculate mean and variance 37 | xr = x.cpu().numpy() 38 | if mu is None: 39 | mu = xr.mean(axis=0) 40 | if sigm is None: 41 | sigm = xr.std(axis=0) 42 | 43 | # replace the value with outliers 44 | random_outliers = random_outliers * sigm[col_idx] + mu[col_idx] 45 | xr[(row_idx, col_idx)] = random_outliers 46 | 47 | # anomaly 48 | anomaly_label = torch.zeros(n).long() 49 | anomaly_label[sample_idx] = 1 50 | 51 | return torch.Tensor(xr), anomaly_label, row_idx, col_idx 52 | 53 | 54 | # add scaled gaussian noise 55 | def scaled_gaussian_noise( 56 | x: torch.Tensor, scale=3.0, min_dist_rel=3.0, filter=True, mu=None, sigm=None 57 | ): 58 | 59 | # calculate mean and variance 60 | if mu is None: 61 | mu = x.mean(dim=0) 62 | if sigm is None: 63 | sigm = x.std(dim=0) 64 | 65 | # noise 66 | noise = torch.randn(x.shape) * sigm * scale 67 | outlier = x + noise 68 | closest_dist = torch.cdist(outlier, x, p=1).min(dim=1)[0] 69 | if filter: 70 | anomaly_label = (closest_dist / x.shape[1] > min_dist_rel).long() 71 | # replace the value with outliers 72 | xr = anomaly_label[:, None] * outlier + (1 - anomaly_label[:, None]) * x 73 | else: 74 | anomaly_label = torch.ones(x.shape[0]).long() 75 | xr = outlier 76 | 77 | return xr, anomaly_label 78 | 79 | 80 | # %% structure outliers 81 | def dense_block( 82 | adj: SparseTensor, 83 | xe: torch.Tensor, 84 | ye=None, 85 | num_nodes: Union[int, Tuple[int, int]] = 5, 86 | num_group: int = 2, 87 | connected_prop=1.0, 88 | feature_anomaly=False, 89 | feature_anomaly_type="outside_ci", 90 | **kwargs, 91 | ): 92 | 93 | n, m = adj.sparse_sizes() 94 | ne = xe.shape[0] 95 | 96 | if isinstance(num_nodes, int): 97 | num_nodes = (num_nodes, num_nodes) 98 | 99 | row = adj.storage.row() 100 | col = adj.storage.col() 101 | ids = torch.stack([row, col]) 102 | 103 | outlier_row = torch.zeros(0).long() 104 | outlier_col = torch.zeros(0).long() 105 | 106 | for i in range(num_group): 107 | rid = np.random.choice(n, num_nodes[0], replace=False) 108 | cid = np.random.choice(m, num_nodes[1], replace=False) 109 | 110 | # all nodes are connected 111 | rows_id = torch.tensor(np.tile(rid[:, None], (1, num_nodes[1])).flatten()) 112 | cols_id = torch.tensor(np.tile(cid, num_nodes[0])) 113 | 114 | # partially dense connection 115 | if connected_prop < 1.0: 116 | n_connected = 
rows_id.shape[0] 117 | n_taken = int(np.ceil(connected_prop * n_connected)) 118 | taken_id = np.random.choice(n_connected, n_taken, replace=False) 119 | 120 | rows_id = rows_id[taken_id] 121 | cols_id = cols_id[taken_id] 122 | 123 | # add to the graph 124 | outlier_row = torch.cat([outlier_row, rows_id]) 125 | outlier_col = torch.cat([outlier_col, cols_id]) 126 | 127 | # only unique ids 128 | outlier_ids = torch.stack([outlier_row, outlier_col]).unique(dim=1) 129 | 130 | # find additional ids that is not in the current adj 131 | ids_all, inv, count = torch.cat([ids, outlier_ids], dim=1).unique( 132 | dim=1, return_counts=True, return_inverse=True 133 | ) 134 | ids_duplicate = ids_all[:, count > 1] 135 | ids_2, count_2 = torch.cat([outlier_ids, ids_duplicate], dim=1).unique( 136 | dim=1, return_counts=True 137 | ) 138 | ids_additional = ids_2[:, count_2 == 1] 139 | 140 | # anomalous label for the original 141 | label_orig = (count[inv][:ne] > 1).long() 142 | 143 | ## features 144 | n_add = ids_additional.shape[1] 145 | # random features for the new edges 146 | add_ids = np.random.choice(ne, n_add, replace=False) 147 | xe_add = xe[add_ids, :] 148 | 149 | # inject feature anomaly 150 | xe2 = xe.clone() 151 | if feature_anomaly: 152 | mu = xe.mean(dim=0).numpy() 153 | sigm = xe.std(dim=0).numpy() 154 | kwargs["mu"] = mu 155 | kwargs["sigm"] = sigm 156 | 157 | if feature_anomaly_type == "outside_ci": 158 | kwargs["prop_sample"] = 1.0 159 | xe_add = outside_cofidence_interval(xe_add, **kwargs)[0] 160 | if label_orig.sum() > 0: 161 | xe2[label_orig == 1, :] = outside_cofidence_interval( 162 | xe[label_orig == 1, :], **kwargs 163 | )[0] 164 | else: 165 | xe2 = xe 166 | elif feature_anomaly_type == "scaled_gaussian": 167 | kwargs["filter"] = False 168 | xe_add = scaled_gaussian_noise(xe_add, **kwargs)[0] 169 | if label_orig.sum() > 0: 170 | xe2[label_orig == 1, :] = scaled_gaussian_noise( 171 | xe[label_orig == 1, :], **kwargs 172 | )[0] 173 | else: 174 | xe2 = xe 175 | 176 | # combine with the previous label if given 177 | ye2 = label_orig if ye is None else torch.logical_or(ye, label_orig).long() 178 | 179 | # attach xe and label to value 180 | ids_cmb = torch.cat([ids, ids_additional], dim=1) 181 | xe_cmb = torch.cat([xe2, xe_add], dim=0) 182 | ye_cmb = torch.cat([ye2, torch.ones(n_add).long()]) 183 | label_cmb = torch.cat([label_orig, torch.ones(n_add).long()]) 184 | value_cmb = torch.cat([xe_cmb, ye_cmb[:, None], label_cmb[:, None]], dim=1) 185 | 186 | # get result 187 | adj_new = SparseTensor(row=ids_cmb[0], col=ids_cmb[1], value=value_cmb).coalesce() 188 | value_new = adj_new.storage.value() 189 | xe_new = value_new[:, :-2] 190 | ye_new = value_new[:, -2].long() 191 | label = value_new[:, -1].long() 192 | adj_new.storage._value = None 193 | 194 | return adj_new, xe_new, ye_new, label 195 | 196 | 197 | # %% graph, insert anomaly 198 | 199 | 200 | def inject_feature_anomaly( 201 | data: BipartiteData, 202 | node_anomaly=True, 203 | edge_anomaly=True, 204 | feature_anomaly_type="outside_ci", 205 | **kwargs, 206 | ): 207 | 208 | if node_anomaly: 209 | if feature_anomaly_type == "outside_ci": 210 | xu, yu2, _, _ = outside_cofidence_interval(data.xu, **kwargs) 211 | xv, yv2, _, _ = outside_cofidence_interval(data.xv, **kwargs) 212 | elif feature_anomaly_type == "scaled_gaussian": 213 | xu, yu2 = scaled_gaussian_noise(data.xu, **kwargs) 214 | xv, yv2 = scaled_gaussian_noise(data.xv, **kwargs) 215 | yu = torch.logical_or(data.yu, yu2).long() if hasattr(data, "yu") else yu2 216 | yv = 
torch.logical_or(data.yv, yv2).long() if hasattr(data, "yv") else yv2 217 | 218 | else: 219 | xu = data.xu 220 | xv = data.xv 221 | yu = data.yu if hasattr(data, "yu") else None 222 | yv = data.yv if hasattr(data, "yv") else None 223 | 224 | if edge_anomaly: 225 | if feature_anomaly_type == "outside_ci": 226 | xe, ye2, _, _ = outside_cofidence_interval(data.xe, **kwargs) 227 | elif feature_anomaly_type == "scaled_gaussian": 228 | xe, ye2 = scaled_gaussian_noise(data.xe, **kwargs) 229 | ye = torch.logical_or(data.ye, ye2).long() if hasattr(data, "ye") else ye2 230 | else: 231 | xe = data.xe 232 | ye = data.ye if hasattr(data, "ye") else None 233 | 234 | data_new = BipartiteData(data.adj, xu=xu, xv=xv, xe=xe, yu=yu, yv=yv, ye=ye) 235 | 236 | return data_new 237 | 238 | 239 | def inject_dense_block_anomaly(data: BipartiteData, **kwargs): 240 | kwargs["feature_anomaly"] = False 241 | ye = data.ye if hasattr(data, "ye") else None 242 | adj_new, xe_new, ye_new, label = dense_block(data.adj, data.xe, ye=ye, **kwargs) 243 | 244 | yu = torch.zeros(data.xu.shape[0]).long() 245 | yu[adj_new.storage.row()[label == 1].unique()] = 1 246 | 247 | yv = torch.zeros(data.xv.shape[0]).long() 248 | yv[adj_new.storage.col()[label == 1].unique()] = 1 249 | 250 | data_new = BipartiteData(adj_new, xu=data.xu, xv=data.xv, xe=xe_new) 251 | data_new.ye = ye_new 252 | data_new.yu = torch.logical_or(data.yu, yu).long() if hasattr(data, "yu") else yu 253 | data_new.yv = torch.logical_or(data.yv, yv).long() if hasattr(data, "yv") else yv 254 | 255 | return data_new 256 | 257 | 258 | def inject_dense_block_and_feature_anomaly( 259 | data: BipartiteData, node_feature_anomaly=False, edge_feature_anomaly=True, **kwargs 260 | ): 261 | 262 | kwargs["feature_anomaly"] = edge_feature_anomaly 263 | if "feature_anomaly_type" not in kwargs: 264 | kwargs["feature_anomaly_type"] = "outside_ci" 265 | 266 | ye = data.ye if hasattr(data, "ye") else None 267 | adj_new, xe_new, ye_new, label = dense_block(data.adj, data.xe, ye=ye, **kwargs) 268 | 269 | yu = torch.zeros(data.xu.shape[0]).long() 270 | yu[adj_new.storage.row()[label == 1].unique()] = 1 271 | 272 | yv = torch.zeros(data.xv.shape[0]).long() 273 | yv[adj_new.storage.col()[label == 1].unique()] = 1 274 | 275 | # also node feature anomaly 276 | if node_feature_anomaly: 277 | 278 | # args 279 | kw2 = {} 280 | 281 | # xu 282 | xu = data.xu 283 | mu = xu.mean(dim=0).numpy() 284 | sigm = xu.std(dim=0).numpy() 285 | kw2["mu"] = mu 286 | kw2["sigm"] = sigm 287 | 288 | if kwargs["feature_anomaly_type"] == "outside_ci": 289 | kw2["prop_sample"] = 1.0 290 | if "prop_feat" in kwargs: 291 | kw2["prop_feat"] = kwargs["prop_feat"] 292 | if "std_cutoff" in kwargs: 293 | kw2["std_cutoff"] = kwargs["std_cutoff"] 294 | xu_new = xu.clone() 295 | xu_new[yu == 1, :] = outside_cofidence_interval(xu[yu == 1, :], **kw2)[0] 296 | elif kwargs["feature_anomaly_type"] == "scaled_gaussian": 297 | kw2["filter"] = False 298 | if "scale" in kwargs: 299 | kw2["scale"] = kwargs["scale"] 300 | if "min_dist_rel" in kwargs: 301 | kw2["min_dist_rel"] = kwargs["min_dist_rel"] 302 | xu_new = xu.clone() 303 | xu_new[yu == 1, :] = scaled_gaussian_noise(xu[yu == 1, :], **kw2)[0] 304 | 305 | # xv 306 | xv = data.xv 307 | mu = xv.mean(dim=0).numpy() 308 | sigm = xv.std(dim=0).numpy() 309 | kw2["mu"] = mu 310 | kw2["sigm"] = sigm 311 | 312 | if kwargs["feature_anomaly_type"] == "outside_ci": 313 | kw2["prop_sample"] = 1.0 314 | if "prop_feat" in kwargs: 315 | kw2["prop_feat"] = kwargs["prop_feat"] 316 | if "std_cutoff" in 
kwargs: 317 | kw2["std_cutoff"] = kwargs["std_cutoff"] 318 | xv_new = xv.clone() 319 | xv_new[yv == 1, :] = outside_cofidence_interval(xv[yv == 1, :], **kw2)[0] 320 | elif kwargs["feature_anomaly_type"] == "scaled_gaussian": 321 | kw2["filter"] = False 322 | if "scale" in kwargs: 323 | kw2["scale"] = kwargs["scale"] 324 | if "min_dist_rel" in kwargs: 325 | kw2["min_dist_rel"] = kwargs["min_dist_rel"] 326 | xv_new = xv.clone() 327 | xv_new[yv == 1, :] = scaled_gaussian_noise(xv[yv == 1, :], **kw2)[0] 328 | 329 | # data 330 | data_new = BipartiteData(adj_new, xu=xu_new, xv=xv_new, xe=xe_new) 331 | data_new.ye = ye_new 332 | data_new.yu = ( 333 | torch.logical_or(data.yu, yu).long() if hasattr(data, "yu") else yu 334 | ) 335 | data_new.yv = ( 336 | torch.logical_or(data.yv, yv).long() if hasattr(data, "yv") else yv 337 | ) 338 | 339 | else: 340 | data_new = BipartiteData(adj_new, xu=data.xu, xv=data.xv, xe=xe_new) 341 | data_new.ye = ye_new 342 | data_new.yu = ( 343 | torch.logical_or(data.yu, yu).long() if hasattr(data, "yu") else yu 344 | ) 345 | data_new.yv = ( 346 | torch.logical_or(data.yv, yv).long() if hasattr(data, "yv") else yv 347 | ) 348 | 349 | return data_new 350 | 351 | 352 | # %% random anomaly 353 | 354 | 355 | def choose(r, choices, thresholds): 356 | i = 0 357 | cm = thresholds[i] 358 | while i < len(choices): 359 | if r <= cm + 1e-9: 360 | selected = i 361 | break 362 | else: 363 | i += 1 364 | if i < len(choices): 365 | cm += thresholds[i] 366 | else: 367 | selected = len(choices) - 1 368 | break 369 | 370 | return choices[selected] 371 | 372 | 373 | def inject_random_block_anomaly( 374 | data: BipartiteData, 375 | num_group=40, 376 | num_nodes_range=(1, 12), 377 | num_nodes_range2=None, 378 | **kwargs, 379 | ): 380 | 381 | block_anomalies = ["full_dense_block", "partial_full_dense_block"] # , 'none'] 382 | feature_anomalies = ["outside_ci", "scaled_gaussian", "none"] 383 | node_edge_feat_anomalies = ["node_only", "edge_only", "node_edge"] 384 | 385 | block_anomalies_weight = [0.2, 0.8] # , 0.1] 386 | feature_anomalies_weight = [0.5, 0.4, 0.1] 387 | node_edge_feat_anomalies_weight = [0.1, 0.3, 0.6] 388 | 389 | data_new = BipartiteData(data.adj, xu=data.xu, xv=data.xv, xe=data.xe) 390 | 391 | # random anomaly 392 | for itg in range(num_group): 393 | 394 | print(f"it {itg}: ", end="") 395 | 396 | rnd = torch.rand(3) 397 | block_an = choose(rnd[0], block_anomalies, block_anomalies_weight) 398 | feature_an = choose(rnd[1], feature_anomalies, feature_anomalies_weight) 399 | node_edge_an = choose( 400 | rnd[2], node_edge_feat_anomalies, node_edge_feat_anomalies_weight 401 | ) 402 | lr, rr, mr = ( 403 | num_nodes_range[0], 404 | num_nodes_range[1], 405 | num_nodes_range[0] + num_nodes_range[1] / 2, 406 | ) 407 | if num_nodes_range2 is not None: 408 | nn1 = int( 409 | np.minimum( 410 | np.maximum(lr, (torch.randn(1).item() * np.sqrt(mr)) + mr), rr + 1 411 | ) 412 | ) 413 | lr2, rr2, mr2 = ( 414 | num_nodes_range2[0], 415 | num_nodes_range2[1], 416 | num_nodes_range2[0] + num_nodes_range2[1] / 2, 417 | ) 418 | nn2 = int( 419 | np.minimum( 420 | np.maximum(lr2, (torch.randn(1).item() * np.sqrt(mr2)) + mr2), 421 | rr2 + 1, 422 | ) 423 | ) 424 | num_nodes = (nn1, nn2) 425 | else: 426 | num_nodes = int( 427 | np.minimum( 428 | np.maximum(lr, (torch.randn(1).item() * np.sqrt(mr)) + mr), rr + 1 429 | ) 430 | ) 431 | 432 | ## setup kwargs 433 | connected_prop = 1.0 434 | if block_an == "partial_full_dense_block": 435 | connected_prop = np.minimum( 436 | np.maximum(0.2, 
(torch.randn(1).item() / 4) + 0.5), 1.0 437 | ) 438 | 439 | prop_feat = np.minimum(np.maximum(0.1, (torch.randn(1).item() / 8) + 0.3), 0.9) 440 | std_cutoff = np.maximum(2.0, torch.randn(1).item() + 3.0) 441 | scale = np.maximum(2.0, torch.randn(1).item() + 3.0) 442 | 443 | ## inject anomaly 444 | node_feature_anomaly = None 445 | if block_an != "none" and feature_an != "none": 446 | node_feature_anomaly = False if node_edge_an == "edge_only" else True 447 | edge_feature_anomaly = False if node_edge_an == "node_only" else True 448 | 449 | if feature_an == "outside_ci": 450 | data_new = inject_dense_block_and_feature_anomaly( 451 | data_new, 452 | node_feature_anomaly, 453 | edge_feature_anomaly, 454 | num_group=1, 455 | num_nodes=num_nodes, 456 | connected_prop=connected_prop, 457 | feature_anomaly_type="outside_ci", 458 | prop_feat=prop_feat, 459 | std_cutoff=std_cutoff, 460 | ) 461 | 462 | elif feature_an == "scaled_gaussian": 463 | data_new = inject_dense_block_and_feature_anomaly( 464 | data_new, 465 | node_feature_anomaly, 466 | edge_feature_anomaly, 467 | num_group=1, 468 | num_nodes=num_nodes, 469 | connected_prop=connected_prop, 470 | feature_anomaly_type="scaled_gaussian", 471 | scale=scale, 472 | ) 473 | 474 | elif block_an != "none" and feature_an == "none": 475 | data_new = inject_dense_block_anomaly( 476 | data_new, 477 | num_group=1, 478 | num_nodes=num_nodes, 479 | connected_prop=connected_prop, 480 | ) 481 | 482 | elif block_an == "none" and feature_an != "none": 483 | node_anomaly = False if node_edge_an == "edge_only" else True 484 | edge_anomaly = False if node_edge_an == "node_only" else True 485 | 486 | if feature_an == "outside_ci": 487 | data_new = inject_feature_anomaly( 488 | data_new, 489 | node_anomaly, 490 | edge_anomaly, 491 | feature_anomaly_type="outside_ci", 492 | prop_feat=prop_feat, 493 | std_cutoff=std_cutoff, 494 | ) 495 | 496 | elif feature_an == "scaled_gaussian": 497 | data_new = inject_feature_anomaly( 498 | data_new, 499 | node_anomaly, 500 | edge_anomaly, 501 | feature_anomaly_type="scaled_gaussian", 502 | scale=scale, 503 | ) 504 | 505 | print( 506 | f"affected: yu = {data_new.yu.sum()}, yv = {data_new.yv.sum()}, ye = {data_new.ye.sum()} ", 507 | end="", 508 | ) 509 | print( 510 | f"[{block_an}:{connected_prop:.2f},{feature_an},{num_nodes},{node_feature_anomaly}]" 511 | ) 512 | 513 | return data_new 514 | -------------------------------------------------------------------------------- /anomalydae_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
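#
# Baseline experiment: AnomalyDAE from pygod. The pipeline mirrors
# adone_experiment.py (the bipartite graph is flattened into a homogeneous
# graph before fitting), with an additional `--alpha` balance parameter
# (default 0.8) weighting the attribute versus structure reconstruction.
#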
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import sys 5 | 6 | from sklearn.metrics import roc_curve, precision_recall_curve, auc 7 | 8 | from data_finefoods import load_graph 9 | 10 | import argparse 11 | import os 12 | 13 | import torch 14 | from torch_geometric.data import Data 15 | from torch_scatter import scatter 16 | from utils.seed import seed_all 17 | 18 | # train a detector 19 | from pygod.models import AnomalyDAE 20 | 21 | # %% args 22 | 23 | parser = argparse.ArgumentParser(description="AnomalyDAE") 24 | parser.add_argument("--name", type=str, default="wikipedia_anomaly", help="name") 25 | parser.add_argument( 26 | "--key", type=str, default="graph_anomaly_list", help="key to the data" 27 | ) 28 | parser.add_argument("--id", type=int, default=0, help="id to the data") 29 | parser.add_argument("--n-epoch", type=int, default=200, help="number of epoch") 30 | parser.add_argument( 31 | "--num-neighbors", type=int, default=-1, help="number of neighbors for node" 32 | ) 33 | parser.add_argument("--batch-size", type=int, default=0, help="batch size") 34 | parser.add_argument("--alpha", type=float, default=0.8, help="balance parameter") 35 | parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") 36 | parser.add_argument("--gpu", type=int, default=0, help="gpu number") 37 | 38 | args1 = vars(parser.parse_args()) 39 | 40 | args2 = { 41 | "seed": 0, 42 | "hidden_channels": 32, 43 | "dropout_prob": 0.0, 44 | } 45 | 46 | args = {**args1, **args2} 47 | 48 | seed_all(args["seed"]) 49 | 50 | result_dir = "results/" 51 | 52 | # %% data 53 | data = load_graph(args["name"], args["key"], args["id"]) 54 | 55 | u_ch = data.xu.shape[1] 56 | v_ch = data.xv.shape[1] 57 | e_ch = data.xe.shape[1] 58 | 59 | print( 60 | f"Data dimension: U node = {data.xu.shape}; V node = {data.xv.shape}; E edge = {data.xe.shape}; \n" 61 | ) 62 | 63 | # %% model 64 | 65 | xu, xv = data.xu, data.xv 66 | xe, adj = data.xe, data.adj 67 | yu, yv, ye = data.yu, data.yv, data.ye 68 | 69 | 70 | # %% to homogen 71 | nu = xu.shape[0] 72 | nv = xv.shape[0] 73 | nn = nu + nv 74 | 75 | # to homogen 76 | row_h = torch.cat([adj.storage.row(), adj.storage.col() + nu]) 77 | col_h = torch.cat([adj.storage.col() + nu, adj.storage.row()]) 78 | edge_index_h = torch.stack([row_h, col_h]) 79 | xuh = torch.cat( 80 | [ 81 | scatter(xe, adj.storage.row(), dim=0, reduce="max"), 82 | scatter(xe, adj.storage.row(), dim=0, reduce="mean"), 83 | ], 84 | dim=1, 85 | ) 86 | xvh = torch.cat( 87 | [ 88 | scatter(xe, adj.storage.col(), dim=0, reduce="max"), 89 | scatter(xe, adj.storage.col(), dim=0, reduce="mean"), 90 | ], 91 | dim=1, 92 | ) 93 | xh = torch.cat([xuh, xvh], dim=0) 94 | yh = torch.cat([yu, yv], dim=0) 95 | data_h = Data(x=xh, edge_index=edge_index_h, y=yh) 96 | 97 | # %% model 98 | 99 | device = torch.device(f'cuda:{args["gpu"]}' if torch.cuda.is_available() else "cpu") 100 | model = AnomalyDAE( 101 | embed_dim=args["hidden_channels"], 102 | out_dim=args["hidden_channels"], 103 | dropout=args["dropout_prob"], 104 | alpha=args["alpha"], 105 | epoch=args["n_epoch"], 106 | lr=args["lr"], 107 | verbose=True, 108 | gpu=args["gpu"], 109 | batch_size=args["batch_size"], 110 | num_neigh=args["num_neighbors"], 111 | ) 112 | 113 | print(args) 114 | print() 115 | 116 | 117 | def auc_eval(pred, y): 118 | 119 | rc_curve = roc_curve(y, pred) 120 | pr_curve = precision_recall_curve(y, pred) 121 | roc_auc = auc(rc_curve[0], rc_curve[1]) 122 | pr_auc = auc(pr_curve[1], 
pr_curve[0]) 123 | 124 | return roc_auc, pr_auc, rc_curve, pr_curve 125 | 126 | 127 | # %% run training 128 | 129 | model.fit(data_h, yh) 130 | score = model.decision_scores_ 131 | 132 | score_u = score[:nu] 133 | score_v = score[nu:] 134 | score_e_u = score_u[adj.storage.row().numpy()] 135 | score_e_v = score_v[adj.storage.col().numpy()] 136 | score_e = (score_e_u + score_e_v) / 2 137 | 138 | u_roc_auc, u_pr_auc, u_rc_curve, u_pr_curve = auc_eval(score_u, yu.numpy()) 139 | v_roc_auc, v_pr_auc, v_rc_curve, v_pr_curve = auc_eval(score_v, yv.numpy()) 140 | e_roc_auc, e_pr_auc, e_rc_curve, e_pr_curve = auc_eval(score_e, ye.numpy()) 141 | 142 | print( 143 | f"Eval | " 144 | + f"u auc-roc: {u_roc_auc:.4f}, v auc-roc: {v_roc_auc:.4f}, e auc-roc: {e_roc_auc:.4f} | " 145 | + f"u auc-pr {u_pr_auc:.4f}, v auc-pr {v_pr_auc:.4f}, e auc-pr {e_pr_auc:.4f}" 146 | ) 147 | 148 | auc_metrics = { 149 | "u_roc_auc": u_roc_auc, 150 | "u_pr_auc": u_pr_auc, 151 | "v_roc_auc": v_roc_auc, 152 | "v_pr_auc": v_pr_auc, 153 | "e_roc_auc": e_roc_auc, 154 | "e_pr_auc": e_pr_auc, 155 | "u_roc_curve": u_rc_curve, 156 | "u_pr_curve": u_pr_curve, 157 | "v_roc_curve": v_rc_curve, 158 | "v_pr_curve": v_pr_curve, 159 | "e_roc_curve": e_rc_curve, 160 | "e_pr_curve": e_pr_curve, 161 | } 162 | anomaly_score = {"score_u": score_u, "score_v": score_v, "score_e": score_e} 163 | 164 | model_stored = { 165 | "args": args, 166 | "auc_metrics": auc_metrics, 167 | "state_dict": model.model.state_dict(), 168 | } 169 | output_stored = {"args": args, "anomaly_score": anomaly_score} 170 | 171 | print("Saving current results...") 172 | torch.save( 173 | model_stored, 174 | os.path.join( 175 | result_dir, 176 | f"anomalydae-{args['name']}-{args['id']}-alpha-{args['alpha']}-model.th", 177 | ), 178 | ) 179 | torch.save( 180 | output_stored, 181 | os.path.join( 182 | result_dir, 183 | f"anomalydae-{args['name']}-{args['id']}-alpha-{args['alpha']}-output.th", 184 | ), 185 | ) 186 | 187 | 188 | print() 189 | print(args) 190 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /data_finefoods.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
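#
# Builds the `finefoods` graph dataset from data/finefoods.csv: `prepare_data`
# aggregates per-user and per-product statistics and embeds the review text
# with the all-MiniLM-L6-v2 sentence transformer, `create_graph` standardizes
# the features and assembles a BipartiteData object with a sparse user-product
# adjacency, and `synth_random` stores the clean graph plus `--n-graph`
# anomaly-injected copies under storage/. Experiments later read one copy via:
#
#   data = load_graph("finefoods_anomaly", "graph_anomaly_list", id=0)
#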
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch_sparse.tensor import SparseTensor 6 | 7 | import numpy as np 8 | from anomaly_insert import inject_random_block_anomaly 9 | 10 | from models.data import BipartiteData 11 | 12 | import torch 13 | from sklearn import preprocessing 14 | 15 | import pandas as pd 16 | 17 | from sentence_transformers import SentenceTransformer 18 | 19 | # %% 20 | 21 | 22 | def standardize(features: np.ndarray) -> np.ndarray: 23 | scaler = preprocessing.StandardScaler() 24 | z = scaler.fit_transform(features) 25 | return z 26 | 27 | 28 | def prepare_data(): 29 | model = SentenceTransformer("all-MiniLM-L6-v2") 30 | df = pd.read_csv(f"data/finefoods.csv") 31 | 32 | df["SummaryCharLen"] = df["Summary"].astype("str").apply(len) 33 | df["TextCharLen"] = df["Text"].astype("str").apply(len) 34 | df["Helpfulness"] = ( 35 | df["HelpfulnessNumerator"] / df["HelpfulnessDenominator"] 36 | ).fillna(0) 37 | 38 | df = df.iloc[:, 1:].sort_values(["ProductId", "UserId", "Time"]) 39 | dfu = df.groupby(["ProductId", "UserId"], as_index=False).last() 40 | 41 | df_product = dfu.groupby("ProductId", as_index=False).agg( 42 | user_count=("UserId", "count"), 43 | helpful_num_mean=("HelpfulnessNumerator", "mean"), 44 | helpful_num_sum=("HelpfulnessNumerator", "sum"), 45 | helpful_mean=("Helpfulness", "mean"), 46 | helpful_sum=("Helpfulness", "sum"), 47 | score_mean=("Score", "mean"), 48 | score_sum=("Score", "sum"), 49 | summary_len_mean=("SummaryCharLen", "mean"), 50 | summary_len_sum=("SummaryCharLen", "sum"), 51 | text_len_mean=("TextCharLen", "mean"), 52 | text_len_sum=("TextCharLen", "sum"), 53 | ) 54 | 55 | df_user = dfu.groupby("UserId", as_index=False).agg( 56 | product_count=("ProductId", "count"), 57 | helpful_num_mean=("HelpfulnessNumerator", "mean"), 58 | helpful_num_sum=("HelpfulnessNumerator", "sum"), 59 | helpful_mean=("Helpfulness", "mean"), 60 | helpful_sum=("Helpfulness", "sum"), 61 | score_mean=("Score", "mean"), 62 | score_sum=("Score", "sum"), 63 | summary_len_mean=("SummaryCharLen", "mean"), 64 | summary_len_sum=("SummaryCharLen", "sum"), 65 | text_len_mean=("TextCharLen", "mean"), 66 | text_len_sum=("TextCharLen", "sum"), 67 | ) 68 | 69 | df_user.to_csv(f"data/finefoods-user.csv") 70 | df_product.to_csv(f"data/finefoods-product.csv") 71 | 72 | sentences = dfu["Text"].astype("str").to_numpy() 73 | embeddings = model.encode(sentences) 74 | cols = [f"v{i}" for i in range(embeddings.shape[1])] 75 | df_review = pd.concat( 76 | [dfu[["ProductId", "UserId"]], pd.DataFrame(embeddings, columns=cols)], axis=1 77 | ) 78 | 79 | df_review.to_csv(f"data/finefoods-review.csv") 80 | 81 | 82 | def create_graph(): 83 | 84 | df_user = pd.read_csv("data/finefoods-user.csv") 85 | df_product = pd.read_csv("data/finefoods-product.csv") 86 | df_review = pd.read_csv("data/finefoods-review.csv") 87 | 88 | df_user["uid"] = df_user.index 89 | df_product["pid"] = df_product.index 90 | 91 | df_user_id = df_user[["UserId", "uid"]] 92 | df_product_id = df_product[["ProductId", "pid"]] 93 | 94 | df_review_2 = df_review.merge( 95 | df_user_id, 96 | on="UserId", 97 | ).merge(df_product_id, on="ProductId") 98 | df_review_2 = df_review_2.sort_values(["uid", "pid"]) 99 | 100 | uid = torch.tensor(df_review_2["uid"].to_numpy()) 101 | pid = torch.tensor(df_review_2["pid"].to_numpy()) 102 | 103 | adj = SparseTensor(row=uid, col=pid) 104 | edge_attr = torch.tensor(standardize(df_review_2.iloc[:, 
3:-2].to_numpy())).float() 105 | 106 | user_attr = torch.tensor(standardize(df_user.iloc[:, 2:-1].to_numpy())).float() 107 | product_attr = torch.tensor( 108 | standardize(df_product.iloc[:, 2:-1].to_numpy()) 109 | ).float() 110 | 111 | data = BipartiteData(adj, xu=user_attr, xv=product_attr, xe=edge_attr) 112 | 113 | return data 114 | 115 | 116 | def store_graph(name: str, dataset): 117 | torch.save(dataset, f"storage/{name}.pt") 118 | 119 | 120 | def load_graph(name: str, key: str, id=None): 121 | if id == None: 122 | data = torch.load(f"storage/{name}.pt") 123 | return data[key] 124 | else: 125 | data = torch.load(f"storage/{name}.pt") 126 | return data[key][id] 127 | 128 | 129 | def synth_random(): 130 | # generate nd store data 131 | import argparse 132 | 133 | parser = argparse.ArgumentParser(description="GraphBEAN") 134 | parser.add_argument("--name", type=str, default="finefoods_anomaly", help="name") 135 | parser.add_argument("--n-graph", type=int, default=5, help="n graph") 136 | 137 | args = vars(parser.parse_args()) 138 | 139 | prepare_data() 140 | graph = create_graph() 141 | store_graph("finefoods-graph", graph) 142 | # graph = torch.load(f'storage/finefoods-graph.pt') 143 | 144 | graph_anomaly_list = [] 145 | for i in range(args["n_graph"]): 146 | print(f"GRAPH ANOMALY {i} >>>>>>>>>>>>>>") 147 | graph_multi_dense = inject_random_block_anomaly( 148 | graph, num_group=100, num_nodes_range=(1, 20) 149 | ) 150 | graph_anomaly_list.append(graph_multi_dense) 151 | print() 152 | 153 | dataset = {"args": args, "graph": graph, "graph_anomaly_list": graph_anomaly_list} 154 | 155 | store_graph(args["name"], dataset) 156 | 157 | 158 | if __name__ == "__main__": 159 | synth_random() 160 | -------------------------------------------------------------------------------- /data_finefoods_small.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
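#
# Builds the smaller `finefoods-small` variant: `sample_data` subsamples the
# user/product/review CSVs produced by data_finefoods.py (24,000 users and
# 12,000 products, drawn with probability proportional to their log-scaled
# activity counts); graph construction and anomaly injection then mirror
# data_finefoods.py.
#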
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch_sparse.tensor import SparseTensor 6 | 7 | import numpy as np 8 | from anomaly_insert import inject_random_block_anomaly 9 | 10 | from models.data import BipartiteData 11 | 12 | import torch 13 | from sklearn import preprocessing 14 | 15 | import pandas as pd 16 | 17 | # %% 18 | 19 | 20 | def standardize(features: np.ndarray) -> np.ndarray: 21 | scaler = preprocessing.StandardScaler() 22 | z = scaler.fit_transform(features) 23 | return z 24 | 25 | 26 | def sample_data(): 27 | df_user = pd.read_csv(f"data/finefoods-user.csv") 28 | df_product = pd.read_csv(f"data/finefoods-product.csv") 29 | df_review = pd.read_csv(f"data/finefoods-review.csv") 30 | 31 | pc = np.log10(df_user["product_count"].to_numpy()) + 1 32 | user_weight = pc / pc.sum() 33 | 34 | uc = np.log10(df_product["user_count"].to_numpy()) + 1 35 | product_weight = uc / uc.sum() 36 | 37 | user_nums = np.random.choice(df_user.shape[0], 24000, replace=False, p=user_weight) 38 | user_ids = df_user["UserId"][user_nums] 39 | 40 | product_nums = np.random.choice( 41 | df_product.shape[0], 12000, replace=False, p=product_weight 42 | ) 43 | product_ids = df_product["ProductId"][product_nums] 44 | 45 | df_review_chosen = df_review[ 46 | df_review["ProductId"].isin(product_ids) & df_review["UserId"].isin(user_ids) 47 | ].iloc[:, 1:] 48 | df_user_chosen = df_user[ 49 | df_user["UserId"].isin(df_review_chosen["UserId"].unique()) 50 | ].iloc[:, 1:] 51 | df_product_chosen = df_product[ 52 | df_product["ProductId"].isin(df_review_chosen["ProductId"].unique()) 53 | ].iloc[:, 1:] 54 | 55 | df_user_chosen.to_csv(f"data/finefoods-small-user.csv") 56 | df_product_chosen.to_csv(f"data/finefoods-small-product.csv") 57 | df_review_chosen.to_csv(f"data/finefoods-small-review.csv") 58 | 59 | 60 | def create_graph(): 61 | 62 | df_user = pd.read_csv("data/finefoods-small-user.csv") 63 | df_product = pd.read_csv("data/finefoods-small-product.csv") 64 | df_review = pd.read_csv("data/finefoods-small-review.csv") 65 | 66 | df_user["uid"] = df_user.index 67 | df_product["pid"] = df_product.index 68 | 69 | df_user_id = df_user[["UserId", "uid"]] 70 | df_product_id = df_product[["ProductId", "pid"]] 71 | 72 | df_review_2 = df_review.merge( 73 | df_user_id, 74 | on="UserId", 75 | ).merge(df_product_id, on="ProductId") 76 | df_review_2 = df_review_2.sort_values(["uid", "pid"]) 77 | 78 | uid = torch.tensor(df_review_2["uid"].to_numpy()) 79 | pid = torch.tensor(df_review_2["pid"].to_numpy()) 80 | 81 | adj = SparseTensor(row=uid, col=pid) 82 | edge_attr = torch.tensor(standardize(df_review_2.iloc[:, 3:-2].to_numpy())).float() 83 | 84 | user_attr = torch.tensor(standardize(df_user.iloc[:, 2:-1].to_numpy())).float() 85 | product_attr = torch.tensor( 86 | standardize(df_product.iloc[:, 2:-1].to_numpy()) 87 | ).float() 88 | 89 | data = BipartiteData(adj, xu=user_attr, xv=product_attr, xe=edge_attr) 90 | 91 | return data 92 | 93 | 94 | def store_graph(name: str, dataset): 95 | torch.save(dataset, f"storage/{name}.pt") 96 | 97 | 98 | def load_graph(name: str, key: str, id=None): 99 | if id == None: 100 | data = torch.load(f"storage/{name}.pt") 101 | return data[key] 102 | else: 103 | data = torch.load(f"storage/{name}.pt") 104 | return data[key][id] 105 | 106 | 107 | def synth_random(): 108 | # generate nd store data 109 | import argparse 110 | 111 | parser = argparse.ArgumentParser(description="GraphBEAN") 112 | parser.add_argument( 113 
| "--name", type=str, default="finefoods-small_anomaly", help="name" 114 | ) 115 | parser.add_argument("--n-graph", type=int, default=10, help="n graph") 116 | 117 | args = vars(parser.parse_args()) 118 | 119 | sample_data() 120 | graph = create_graph() 121 | store_graph("finefoods-small-graph", graph) 122 | # graph = torch.load(f'storage/finefoods-small-graph.pt') 123 | 124 | graph_anomaly_list = [] 125 | for i in range(args["n_graph"]): 126 | print(f"GRAPH ANOMALY {i} >>>>>>>>>>>>>>") 127 | graph_multi_dense = inject_random_block_anomaly( 128 | graph, num_group=20, num_nodes_range=(1, 12) 129 | ) 130 | graph_anomaly_list.append(graph_multi_dense) 131 | print() 132 | 133 | dataset = {"args": args, "graph": graph, "graph_anomaly_list": graph_anomaly_list} 134 | 135 | store_graph(args["name"], dataset) 136 | 137 | 138 | if __name__ == "__main__": 139 | synth_random() 140 | -------------------------------------------------------------------------------- /data_movies.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch_sparse.tensor import SparseTensor 6 | 7 | import numpy as np 8 | from anomaly_insert import inject_random_block_anomaly 9 | 10 | from models.data import BipartiteData 11 | 12 | import torch 13 | from sklearn import preprocessing 14 | 15 | import pandas as pd 16 | 17 | from sentence_transformers import SentenceTransformer 18 | 19 | 20 | # %% 21 | 22 | 23 | def standardize(features: np.ndarray) -> np.ndarray: 24 | scaler = preprocessing.StandardScaler() 25 | z = scaler.fit_transform(features) 26 | return z 27 | 28 | 29 | def prepare_data(): 30 | model = SentenceTransformer("all-MiniLM-L6-v2") 31 | df = pd.read_csv(f"data/movies.csv") 32 | 33 | df["summary_char_len"] = df["summary"].astype("str").apply(len) 34 | df["text_char_len"] = df["text"].astype("str").apply(len) 35 | df["helpfulness"] = ( 36 | df["helpfulness_numerator"] / df["helpfulness_denominator"] 37 | ).fillna(0) 38 | 39 | df = df.sort_values(["product_id", "user_id", "time"]) 40 | dfu = df.groupby(["product_id", "user_id"], as_index=False).last() 41 | 42 | df_product = dfu.groupby("product_id", as_index=False).agg( 43 | user_count=("user_id", "count"), 44 | helpful_num_mean=("helpfulness_numerator", "mean"), 45 | helpful_num_sum=("helpfulness_numerator", "sum"), 46 | helpful_mean=("helpfulness", "mean"), 47 | helpful_sum=("helpfulness", "sum"), 48 | score_mean=("score", "mean"), 49 | score_sum=("score", "sum"), 50 | summary_len_mean=("summary_char_len", "mean"), 51 | summary_len_sum=("summary_char_len", "sum"), 52 | text_len_mean=("text_char_len", "mean"), 53 | text_len_sum=("text_char_len", "sum"), 54 | ) 55 | 56 | df_user = dfu.groupby("user_id", as_index=False).agg( 57 | product_count=("product_id", "count"), 58 | helpful_num_mean=("helpfulness_numerator", "mean"), 59 | helpful_num_sum=("helpfulness_numerator", "sum"), 60 | helpful_mean=("helpfulness", "mean"), 61 | helpful_sum=("helpfulness", "sum"), 62 | score_mean=("score", "mean"), 63 | score_sum=("score", "sum"), 64 | summary_len_mean=("summary_char_len", "mean"), 65 | summary_len_sum=("summary_char_len", "sum"), 66 | text_len_mean=("text_char_len", "mean"), 67 | text_len_sum=("text_char_len", "sum"), 68 | ) 69 | 70 | df_user.to_csv(f"data/movies-user.csv") 71 | df_product.to_csv(f"data/movies-product.csv") 72 | 73 
| sentences = dfu["text"].astype("str").to_numpy() 74 | embeddings = model.encode(sentences) 75 | 76 | np.save(f"data/movies-embeddings.npy", embeddings) 77 | dfu[["product_id", "user_id"]].to_csv(f"data/movies-ids.csv") 78 | 79 | 80 | def create_graph(): 81 | 82 | df_user = pd.read_csv("data/movies-user.csv") 83 | df_product = pd.read_csv("data/movies-product.csv") 84 | df_review_id = pd.read_csv("data/movies-ids.csv") 85 | embeddings = np.load("data/movies-embeddings.npy") 86 | 87 | df_user["uid"] = df_user.index 88 | df_product["pid"] = df_product.index 89 | 90 | df_user_id = df_user[["user_id", "uid"]] 91 | df_product_id = df_product[["product_id", "pid"]] 92 | 93 | cols = [f"v{i}" for i in range(embeddings.shape[1])] 94 | df_review = pd.concat( 95 | [df_review_id, pd.DataFrame(embeddings, columns=cols)], axis=1 96 | ) 97 | 98 | df_review_2 = df_review.merge( 99 | df_user_id, 100 | on="user_id", 101 | ).merge(df_product_id, on="product_id") 102 | df_review_2 = df_review_2.sort_values(["uid", "pid"]) 103 | 104 | uid = torch.tensor(df_review_2["uid"].to_numpy()) 105 | pid = torch.tensor(df_review_2["pid"].to_numpy()) 106 | 107 | adj = SparseTensor(row=uid, col=pid) 108 | edge_attr = torch.tensor(standardize(df_review_2.iloc[:, 3:-2].to_numpy())).float() 109 | 110 | user_attr = torch.tensor(standardize(df_user.iloc[:, 2:-1].to_numpy())).float() 111 | product_attr = torch.tensor( 112 | standardize(df_product.iloc[:, 2:-1].to_numpy()) 113 | ).float() 114 | 115 | data = BipartiteData(adj, xu=user_attr, xv=product_attr, xe=edge_attr) 116 | 117 | return data 118 | 119 | 120 | def store_graph(name: str, dataset): 121 | torch.save(dataset, f"storage/{name}.pt") 122 | 123 | 124 | def load_graph(name: str, key: str, id=None): 125 | if id == None: 126 | data = torch.load(f"storage/{name}.pt") 127 | return data[key] 128 | else: 129 | data = torch.load(f"storage/{name}.pt") 130 | return data[key][id] 131 | 132 | 133 | def synth_random(): 134 | # generate nd store data 135 | import argparse 136 | 137 | parser = argparse.ArgumentParser(description="GraphBEAN") 138 | parser.add_argument("--name", type=str, default="movies_anomaly", help="name") 139 | parser.add_argument("--n-graph", type=int, default=2, help="n graph") 140 | 141 | args = vars(parser.parse_args()) 142 | 143 | prepare_data() 144 | graph = create_graph() 145 | store_graph("movies-graph", graph) 146 | # graph = torch.load(f'storage/movies-graph.pt') 147 | print(graph) 148 | 149 | graph_anomaly_list = [] 150 | for i in range(args["n_graph"]): 151 | print(f"GRAPH ANOMALY {i} >>>>>>>>>>>>>>") 152 | graph_multi_dense = inject_random_block_anomaly( 153 | graph, num_group=100, num_nodes_range=(1, 20) 154 | ) 155 | graph_anomaly_list.append(graph_multi_dense) 156 | print() 157 | 158 | dataset = {"args": args, "graph": graph, "graph_anomaly_list": graph_anomaly_list} 159 | 160 | store_graph(args["name"], dataset) 161 | 162 | 163 | if __name__ == "__main__": 164 | synth_random() 165 | -------------------------------------------------------------------------------- /data_movies_small.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch_sparse.tensor import SparseTensor 6 | 7 | import numpy as np 8 | from anomaly_insert import inject_random_block_anomaly 9 | 10 | from models.data import BipartiteData 11 | 12 | import torch 13 | from sklearn import preprocessing 14 | 15 | import pandas as pd 16 | 17 | # %% 18 | 19 | 20 | def standardize(features: np.ndarray) -> np.ndarray: 21 | scaler = preprocessing.StandardScaler() 22 | z = scaler.fit_transform(features) 23 | return z 24 | 25 | 26 | def sample_data(): 27 | df_user = pd.read_csv(f"data/movies-user.csv") 28 | df_product = pd.read_csv(f"data/movies-product.csv") 29 | df_review = pd.read_csv(f"data/movies-review.csv") 30 | 31 | pc = np.log10(df_user["product_count"].to_numpy()) + 1 32 | user_weight = pc / pc.sum() 33 | 34 | uc = np.log10(df_product["user_count"].to_numpy()) + 1 35 | product_weight = uc / uc.sum() 36 | 37 | user_nums = np.random.choice(df_user.shape[0], 28000, replace=False, p=user_weight) 38 | user_ids = df_user["user_id"][user_nums] 39 | 40 | product_nums = np.random.choice( 41 | df_product.shape[0], 14000, replace=False, p=product_weight 42 | ) 43 | product_ids = df_product["product_id"][product_nums] 44 | 45 | df_review_chosen = df_review[ 46 | df_review["product_id"].isin(product_ids) & df_review["user_id"].isin(user_ids) 47 | ].iloc[:, 1:] 48 | df_user_chosen = df_user[ 49 | df_user["user_id"].isin(df_review_chosen["user_id"].unique()) 50 | ].iloc[:, 1:] 51 | df_product_chosen = df_product[ 52 | df_product["product_id"].isin(df_review_chosen["product_id"].unique()) 53 | ].iloc[:, 1:] 54 | 55 | df_user_chosen.to_csv(f"data/movies-small-user.csv") 56 | df_product_chosen.to_csv(f"data/movies-small-product.csv") 57 | df_review_chosen.to_csv(f"data/movies-small-review.csv") 58 | 59 | 60 | def create_graph(): 61 | 62 | df_user = pd.read_csv("data/movies-small-user.csv") 63 | df_product = pd.read_csv("data/movies-small-product.csv") 64 | df_review = pd.read_csv("data/movies-small-review.csv") 65 | 66 | df_user["uid"] = df_user.index 67 | df_product["pid"] = df_product.index 68 | 69 | df_user_id = df_user[["user_id", "uid"]] 70 | df_product_id = df_product[["product_id", "pid"]] 71 | 72 | df_review_2 = df_review.merge( 73 | df_user_id, 74 | on="user_id", 75 | ).merge(df_product_id, on="product_id") 76 | df_review_2 = df_review_2.sort_values(["uid", "pid"]) 77 | 78 | uid = torch.tensor(df_review_2["uid"].to_numpy()) 79 | pid = torch.tensor(df_review_2["pid"].to_numpy()) 80 | 81 | adj = SparseTensor(row=uid, col=pid) 82 | edge_attr = torch.tensor(standardize(df_review_2.iloc[:, 3:-2].to_numpy())).float() 83 | 84 | user_attr = torch.tensor(standardize(df_user.iloc[:, 2:-1].to_numpy())).float() 85 | product_attr = torch.tensor( 86 | standardize(df_product.iloc[:, 2:-1].to_numpy()) 87 | ).float() 88 | 89 | data = BipartiteData(adj, xu=user_attr, xv=product_attr, xe=edge_attr) 90 | 91 | return data 92 | 93 | 94 | def store_graph(name: str, dataset): 95 | torch.save(dataset, f"storage/{name}.pt") 96 | 97 | 98 | def load_graph(name: str, key: str, id=None): 99 | if id == None: 100 | data = torch.load(f"storage/{name}.pt") 101 | return data[key] 102 | else: 103 | data = torch.load(f"storage/{name}.pt") 104 | return data[key][id] 105 | 106 | 107 | def synth_random(): 108 | # generate nd store data 109 | import argparse 110 | 111 | parser = argparse.ArgumentParser(description="GraphBEAN") 112 | parser.add_argument("--name", type=str, 
default="movies-small_anomaly", help="name") 113 | parser.add_argument("--n-graph", type=int, default=10, help="n graph") 114 | 115 | args = vars(parser.parse_args()) 116 | 117 | sample_data() 118 | graph = create_graph() 119 | store_graph("movies-small-graph", graph) 120 | # graph = torch.load(f'storage/movies-small-graph.pt') 121 | 122 | graph_anomaly_list = [] 123 | for i in range(args["n_graph"]): 124 | print(f"GRAPH ANOMALY {i} >>>>>>>>>>>>>>") 125 | graph_multi_dense = inject_random_block_anomaly( 126 | graph, num_group=20, num_nodes_range=(1, 12) 127 | ) 128 | graph_anomaly_list.append(graph_multi_dense) 129 | print() 130 | 131 | dataset = {"args": args, "graph": graph, "graph_anomaly_list": graph_anomaly_list} 132 | 133 | store_graph(args["name"], dataset) 134 | 135 | 136 | if __name__ == "__main__": 137 | synth_random() 138 | -------------------------------------------------------------------------------- /data_reddit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch_sparse.tensor import SparseTensor 6 | 7 | import numpy as np 8 | from anomaly_insert import inject_random_block_anomaly 9 | 10 | from models.data import BipartiteData 11 | 12 | import torch 13 | from sklearn import preprocessing 14 | 15 | import pandas as pd 16 | 17 | # %% 18 | 19 | 20 | def standardize(features: np.ndarray) -> np.ndarray: 21 | scaler = preprocessing.StandardScaler() 22 | z = scaler.fit_transform(features) 23 | return z 24 | 25 | 26 | def prepare_data(): 27 | 28 | cols = ["user_id", "item_id", "timestamp", "state_label"] + [ 29 | f"v{i+1}" for i in range(172) 30 | ] 31 | df = pd.read_csv(f"data/wikipedia.csv", skiprows=1, names=cols) 32 | 33 | # edge 34 | cols_d = {"item_id": [("n_action", "count")]} 35 | for i in range(172): 36 | cols_d[f"v{i+1}"] = [(f"v{i+1}_mean", "mean"), (f"v{i+1}_max", "max")] 37 | 38 | df_edge = df.groupby(["user_id", "item_id"]).agg(cols_d) 39 | df_edge = df_edge.droplevel(axis=1, level=0).reset_index() 40 | df_edge.to_csv(f"data/reddit-edge.csv") 41 | 42 | # user 43 | cols_d = {"item_id": [("n_item", "nunique"), ("n_action", "count")]} 44 | for i in range(172): 45 | cols_d[f"v{i+1}"] = [(f"v{i+1}_mean", "mean")] 46 | 47 | df_user = df.groupby(["user_id"]).agg(cols_d) 48 | 49 | df_user = df_user.droplevel(axis=1, level=0).reset_index() 50 | df_user.to_csv(f"data/reddit-user.csv") 51 | 52 | # item 53 | cols_d = {"user_id": [("n_user", "nunique"), ("n_action", "count")]} 54 | for i in range(172): 55 | cols_d[f"v{i+1}"] = [(f"v{i+1}_mean", "mean")] 56 | 57 | df_item = df.groupby(["item_id"]).agg(cols_d) 58 | df_item = df_item.droplevel(axis=1, level=0).reset_index() 59 | df_item.to_csv(f"data/reddit-item.csv") 60 | 61 | 62 | def create_graph(): 63 | 64 | df_user = pd.read_csv("data/reddit-user.csv") 65 | df_item = pd.read_csv("data/reddit-item.csv") 66 | df_edge = pd.read_csv("data/reddit-edge.csv") 67 | 68 | df_user["uid"] = df_user.index 69 | df_item["iid"] = df_item.index 70 | 71 | df_user_id = df_user[["user_id", "uid"]] 72 | df_item_id = df_item[["item_id", "iid"]] 73 | 74 | df_edge_2 = df_edge.merge( 75 | df_user_id, 76 | on="user_id", 77 | ).merge(df_item_id, on="item_id") 78 | df_edge_2 = df_edge_2.sort_values(["uid", "iid"]) 79 | 80 | uid = torch.tensor(df_edge_2["uid"].to_numpy()) 81 | iid = 
torch.tensor(df_edge_2["iid"].to_numpy()) 82 | 83 | adj = SparseTensor(row=uid, col=iid) 84 | edge_attr = torch.tensor(standardize(df_edge_2.iloc[:, 3:-2].to_numpy())).float() 85 | 86 | user_attr = torch.tensor(standardize(df_user.iloc[:, 2:-1].to_numpy())).float() 87 | product_attr = torch.tensor(standardize(df_item.iloc[:, 2:-1].to_numpy())).float() 88 | 89 | data = BipartiteData(adj, xu=user_attr, xv=product_attr, xe=edge_attr) 90 | 91 | return data 92 | 93 | 94 | def store_graph(name: str, dataset): 95 | torch.save(dataset, f"storage/{name}.pt") 96 | 97 | 98 | def load_graph(name: str, key: str, id=None): 99 | if id == None: 100 | data = torch.load(f"storage/{name}.pt") 101 | return data[key] 102 | else: 103 | data = torch.load(f"storage/{name}.pt") 104 | return data[key][id] 105 | 106 | 107 | def synth_random(): 108 | # generate nd store data 109 | import argparse 110 | 111 | parser = argparse.ArgumentParser(description="GraphBEAN") 112 | parser.add_argument("--name", type=str, default="reddit_anomaly", help="name") 113 | parser.add_argument("--n-graph", type=int, default=10, help="n graph") 114 | 115 | args = vars(parser.parse_args()) 116 | 117 | prepare_data() 118 | graph = create_graph() 119 | store_graph("reddit-graph", graph) 120 | # graph = torch.load(f'storage/reddit-graph.pt') 121 | 122 | graph_anomaly_list = [] 123 | for i in range(args["n_graph"]): 124 | print(f"GRAPH ANOMALY {i} >>>>>>>>>>>>>>") 125 | print(graph) 126 | graph_multi_dense = inject_random_block_anomaly( 127 | graph, num_group=30, num_nodes_range=(1, 30), num_nodes_range2=(1, 6) 128 | ) 129 | graph_anomaly_list.append(graph_multi_dense) 130 | print() 131 | 132 | dataset = {"args": args, "graph": graph, "graph_anomaly_list": graph_anomaly_list} 133 | 134 | store_graph(args["name"], dataset) 135 | 136 | 137 | if __name__ == "__main__": 138 | synth_random() 139 | -------------------------------------------------------------------------------- /data_wikipedia.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch_sparse.tensor import SparseTensor 6 | 7 | import numpy as np 8 | from anomaly_insert import inject_random_block_anomaly 9 | 10 | from models.data import BipartiteData 11 | 12 | import torch 13 | from sklearn import preprocessing 14 | 15 | import pandas as pd 16 | 17 | # %% 18 | 19 | 20 | def standardize(features: np.ndarray) -> np.ndarray: 21 | scaler = preprocessing.StandardScaler() 22 | z = scaler.fit_transform(features) 23 | return z 24 | 25 | 26 | def prepare_data(): 27 | 28 | cols = ["user_id", "item_id", "timestamp", "state_label"] + [ 29 | f"v{i+1}" for i in range(172) 30 | ] 31 | df = pd.read_csv(f"data/wikipedia.csv", skiprows=1, names=cols) 32 | 33 | # edge 34 | cols_d = {"item_id": [("n_action", "count")]} 35 | for i in range(172): 36 | cols_d[f"v{i+1}"] = [(f"v{i+1}_mean", "mean"), (f"v{i+1}_max", "max")] 37 | 38 | df_edge = df.groupby(["user_id", "item_id"]).agg(cols_d) 39 | df_edge = df_edge.droplevel(axis=1, level=0).reset_index() 40 | df_edge.to_csv(f"data/wikipedia-edge.csv") 41 | 42 | # user 43 | cols_d = {"item_id": [("n_item", "nunique"), ("n_action", "count")]} 44 | for i in range(172): 45 | cols_d[f"v{i+1}"] = [(f"v{i+1}_mean", "mean")] 46 | 47 | df_user = df.groupby(["user_id"]).agg(cols_d) 48 | 49 | df_user = df_user.droplevel(axis=1, level=0).reset_index() 50 | df_user.to_csv(f"data/wikipedia-user.csv") 51 | 52 | # item 53 | cols_d = {"user_id": [("n_user", "nunique"), ("n_action", "count")]} 54 | for i in range(172): 55 | cols_d[f"v{i+1}"] = [(f"v{i+1}_mean", "mean")] 56 | 57 | df_item = df.groupby(["item_id"]).agg(cols_d) 58 | df_item = df_item.droplevel(axis=1, level=0).reset_index() 59 | df_item.to_csv(f"data/wikipedia-item.csv") 60 | 61 | 62 | def create_graph(): 63 | 64 | df_user = pd.read_csv("data/wikipedia-user.csv") 65 | df_item = pd.read_csv("data/wikipedia-item.csv") 66 | df_edge = pd.read_csv("data/wikipedia-edge.csv") 67 | 68 | df_user["uid"] = df_user.index 69 | df_item["iid"] = df_item.index 70 | 71 | df_user_id = df_user[["user_id", "uid"]] 72 | df_item_id = df_item[["item_id", "iid"]] 73 | 74 | df_edge_2 = df_edge.merge( 75 | df_user_id, 76 | on="user_id", 77 | ).merge(df_item_id, on="item_id") 78 | df_edge_2 = df_edge_2.sort_values(["uid", "iid"]) 79 | 80 | uid = torch.tensor(df_edge_2["uid"].to_numpy()) 81 | iid = torch.tensor(df_edge_2["iid"].to_numpy()) 82 | 83 | adj = SparseTensor(row=uid, col=iid) 84 | edge_attr = torch.tensor(standardize(df_edge_2.iloc[:, 3:-2].to_numpy())).float() 85 | 86 | user_attr = torch.tensor(standardize(df_user.iloc[:, 2:-1].to_numpy())).float() 87 | product_attr = torch.tensor(standardize(df_item.iloc[:, 2:-1].to_numpy())).float() 88 | 89 | data = BipartiteData(adj, xu=user_attr, xv=product_attr, xe=edge_attr) 90 | 91 | return data 92 | 93 | 94 | def store_graph(name: str, dataset): 95 | torch.save(dataset, f"storage/{name}.pt") 96 | 97 | 98 | def load_graph(name: str, key: str, id=None): 99 | if id == None: 100 | data = torch.load(f"storage/{name}.pt") 101 | return data[key] 102 | else: 103 | data = torch.load(f"storage/{name}.pt") 104 | return data[key][id] 105 | 106 | 107 | def synth_random(): 108 | # generate nd store data 109 | import argparse 110 | 111 | parser = argparse.ArgumentParser(description="GraphBEAN") 112 | parser.add_argument("--name", type=str, default="wikipedia_anomaly", help="name") 113 | parser.add_argument("--n-graph", type=int, default=10, help="n 
graph") 114 | 115 | args = vars(parser.parse_args()) 116 | 117 | prepare_data() 118 | graph = create_graph() 119 | store_graph("wikipedia-graph", graph) 120 | # graph = torch.load(f'storage/wikipedia-graph.pt') 121 | 122 | graph_anomaly_list = [] 123 | for i in range(args["n_graph"]): 124 | print(f"GRAPH ANOMALY {i} >>>>>>>>>>>>>>") 125 | print(graph) 126 | graph_multi_dense = inject_random_block_anomaly( 127 | graph, num_group=20, num_nodes_range=(1, 20), num_nodes_range2=(1, 6) 128 | ) 129 | graph_anomaly_list.append(graph_multi_dense) 130 | print() 131 | 132 | dataset = {"args": args, "graph": graph, "graph_anomaly_list": graph_anomaly_list} 133 | 134 | store_graph(args["name"], dataset) 135 | 136 | 137 | if __name__ == "__main__": 138 | synth_random() 139 | -------------------------------------------------------------------------------- /dominant_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import sys 5 | 6 | from sklearn.metrics import roc_curve, precision_recall_curve, auc 7 | 8 | from data_finefoods import load_graph 9 | 10 | import argparse 11 | import os 12 | 13 | import torch 14 | from torch_geometric.data import Data 15 | from torch_scatter import scatter 16 | 17 | from utils.seed import seed_all 18 | 19 | # train a dominant detector 20 | from pygod.models import DOMINANT 21 | 22 | # %% args 23 | 24 | parser = argparse.ArgumentParser(description="DOMINANT") 25 | parser.add_argument("--name", type=str, default="wikipedia_anomaly", help="name") 26 | parser.add_argument( 27 | "--key", type=str, default="graph_anomaly_list", help="key to the data" 28 | ) 29 | parser.add_argument("--id", type=int, default=0, help="id to the data") 30 | parser.add_argument("--n-epoch", type=int, default=200, help="number of epoch") 31 | parser.add_argument( 32 | "--num-neighbors", type=int, default=-1, help="number of neighbors for node" 33 | ) 34 | parser.add_argument("--batch-size", type=int, default=0, help="batch size") 35 | parser.add_argument("--alpha", type=float, default=0.8, help="balance parameter") 36 | parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") 37 | parser.add_argument("--gpu", type=int, default=0, help="gpu number") 38 | 39 | 40 | args1 = vars(parser.parse_args()) 41 | 42 | args2 = { 43 | "seed": 0, 44 | "hidden_channels": 32, 45 | "dropout_prob": 0.0, 46 | } 47 | 48 | args = {**args1, **args2} 49 | 50 | seed_all(args["seed"]) 51 | 52 | result_dir = "results/" 53 | 54 | # %% data 55 | data = load_graph(args["name"], args["key"], args["id"]) 56 | 57 | u_ch = data.xu.shape[1] 58 | v_ch = data.xv.shape[1] 59 | e_ch = data.xe.shape[1] 60 | 61 | print( 62 | f"Data dimension: U node = {data.xu.shape}; V node = {data.xv.shape}; E edge = {data.xe.shape}; \n" 63 | ) 64 | 65 | # %% model 66 | 67 | xu, xv = data.xu, data.xv 68 | xe, adj = data.xe, data.adj 69 | yu, yv, ye = data.yu, data.yv, data.ye 70 | 71 | 72 | # %% to homogen 73 | nu = xu.shape[0] 74 | nv = xv.shape[0] 75 | nn = nu + nv 76 | 77 | # to homogen 78 | row_h = torch.cat([adj.storage.row(), adj.storage.col() + nu]) 79 | col_h = torch.cat([adj.storage.col() + nu, adj.storage.row()]) 80 | edge_index_h = torch.stack([row_h, col_h]) 81 | xuh = torch.cat( 82 | [ 83 | scatter(xe, adj.storage.row(), dim=0, reduce="max"), 84 | scatter(xe, adj.storage.row(), dim=0, 
reduce="mean"), 85 | ], 86 | dim=1, 87 | ) 88 | xvh = torch.cat( 89 | [ 90 | scatter(xe, adj.storage.col(), dim=0, reduce="max"), 91 | scatter(xe, adj.storage.col(), dim=0, reduce="mean"), 92 | ], 93 | dim=1, 94 | ) 95 | xh = torch.cat([xuh, xvh], dim=0) 96 | yh = torch.cat([yu, yv], dim=0) 97 | data_h = Data(x=xh, edge_index=edge_index_h, y=yh) 98 | 99 | # %% model 100 | 101 | device = torch.device(f'cuda:{args["gpu"]}' if torch.cuda.is_available() else "cpu") 102 | model = DOMINANT( 103 | hid_dim=args["hidden_channels"], 104 | num_layers=4, 105 | dropout=args["dropout_prob"], 106 | alpha=args["alpha"], 107 | epoch=args["n_epoch"], 108 | lr=args["lr"], 109 | verbose=True, 110 | gpu=args["gpu"], 111 | batch_size=args["batch_size"], 112 | num_neigh=args["num_neighbors"], 113 | ) 114 | 115 | print(args) 116 | print() 117 | 118 | 119 | def auc_eval(pred, y): 120 | 121 | rc_curve = roc_curve(y, pred) 122 | pr_curve = precision_recall_curve(y, pred) 123 | roc_auc = auc(rc_curve[0], rc_curve[1]) 124 | pr_auc = auc(pr_curve[1], pr_curve[0]) 125 | 126 | return roc_auc, pr_auc, rc_curve, pr_curve 127 | 128 | 129 | # %% run training 130 | 131 | print("ready to run") 132 | 133 | model.fit(data_h, yh) 134 | score = model.decision_scores_ 135 | 136 | score_u = score[:nu] 137 | score_v = score[nu:] 138 | score_e_u = score_u[adj.storage.row().numpy()] 139 | score_e_v = score_v[adj.storage.col().numpy()] 140 | score_e = (score_e_u + score_e_v) / 2 141 | 142 | u_roc_auc, u_pr_auc, u_rc_curve, u_pr_curve = auc_eval(score_u, yu.numpy()) 143 | v_roc_auc, v_pr_auc, v_rc_curve, v_pr_curve = auc_eval(score_v, yv.numpy()) 144 | e_roc_auc, e_pr_auc, e_rc_curve, e_pr_curve = auc_eval(score_e, ye.numpy()) 145 | 146 | print( 147 | f"Eval | " 148 | + f"u auc-roc: {u_roc_auc:.4f}, v auc-roc: {v_roc_auc:.4f}, e auc-roc: {e_roc_auc:.4f} | " 149 | + f"u auc-pr {u_pr_auc:.4f}, v auc-pr {v_pr_auc:.4f}, e auc-pr {e_pr_auc:.4f}" 150 | ) 151 | 152 | auc_metrics = { 153 | "u_roc_auc": u_roc_auc, 154 | "u_pr_auc": u_pr_auc, 155 | "v_roc_auc": v_roc_auc, 156 | "v_pr_auc": v_pr_auc, 157 | "e_roc_auc": e_roc_auc, 158 | "e_pr_auc": e_pr_auc, 159 | "u_roc_curve": u_rc_curve, 160 | "u_pr_curve": u_pr_curve, 161 | "v_roc_curve": v_rc_curve, 162 | "v_pr_curve": v_pr_curve, 163 | "e_roc_curve": e_rc_curve, 164 | "e_pr_curve": e_pr_curve, 165 | } 166 | anomaly_score = {"score_u": score_u, "score_v": score_v, "score_e": score_e} 167 | 168 | model_stored = { 169 | "args": args, 170 | "auc_metrics": auc_metrics, 171 | "state_dict": model.model.state_dict(), 172 | } 173 | output_stored = {"args": args, "anomaly_score": anomaly_score} 174 | 175 | print("Saving current results...") 176 | torch.save( 177 | model_stored, 178 | os.path.join( 179 | result_dir, f"dominant-{args['name']}-{args['id']}-alpha-{args['alpha']}-model.th" 180 | ), 181 | ) 182 | torch.save( 183 | output_stored, 184 | os.path.join( 185 | result_dir, 186 | f"dominant-{args['name']}-{args['id']}-alpha-{args['alpha']}-output.th", 187 | ), 188 | ) 189 | 190 | 191 | print() 192 | print(args) 193 | -------------------------------------------------------------------------------- /extract_movies.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | with open(f"data/movies.txt", "r", errors="ignore") as infile: 5 | movies = infile.readlines() 6 | 7 | infile.close() 8 | with open(f"data/movies.csv", "w") as out: 9 | out.write( 10 | "product_id,user_id,profile_name,helpfulness_numerator,helpfulness_denominator,score,time,summary,text\n" 11 | ) 12 | 13 | count = 0 14 | out_dict = { 15 | "product/productId": "", 16 | "review/userId": "", 17 | "review/profileName": "", 18 | "review/helpfulness": "", 19 | "review/score": "", 20 | "review/time": "", 21 | "review/summary": "", 22 | "review/text": "", 23 | "hnum": 0, 24 | "hden": 0, 25 | } 26 | 27 | for row in movies: 28 | if row.rstrip() != "": 29 | cells = row.split(":") 30 | if cells[0] == "product/productId": 31 | if len(cells) > 1: 32 | out_dict[cells[0]] = ( 33 | cells[1] 34 | .replace(",", "") 35 | .replace("\n", "") 36 | .replace("
", "") 37 | .replace("\\", "") 38 | .strip() 39 | ) 40 | if count > 0: 41 | output = ( 42 | f"{out_dict['product/productId']},{out_dict['review/userId']},{out_dict['review/profileName']},{out_dict['hnum']},{out_dict['hden']}," 43 | + f"{out_dict['review/score']},{out_dict['review/time']},{out_dict['review/summary']},{out_dict['review/text']}\n" 44 | ) 45 | out.write(output) 46 | count += 1 47 | if count % 1000 == 0: 48 | out.flush() 49 | print(count) 50 | elif cells[0] == "review/helpfulness": 51 | if len(cells) > 1: 52 | if "/" in cells[1]: 53 | hs = cells[1].split("/") 54 | out_dict["hnum"] = int(hs[0]) 55 | out_dict["hden"] = int(hs[1]) 56 | else: 57 | out_dict["hnum"] = int(cells[1]) 58 | out_dict["hden"] = int(cells[1]) 59 | elif cells[0] == "review/text": 60 | if len(cells) > 1: 61 | out_dict[cells[0]] = ( 62 | '"' 63 | + ":".join(cells[1:]) 64 | .replace(",", "") 65 | .replace("\n", "") 66 | .replace("
", "") 67 | .replace("\\", "") 68 | .replace('"', "") 69 | .replace("'", "") 70 | .strip() 71 | + '"' 72 | ) 73 | else: 74 | if len(cells) > 1: 75 | out_dict[cells[0]] = ( 76 | cells[1] 77 | .replace(",", "") 78 | .replace("\n", "") 79 | .replace("
", "") 80 | .replace("\\", "") 81 | .strip() 82 | ) 83 | 84 | output = ( 85 | f"{out_dict['product/productId']},{out_dict['review/userId']},{out_dict['review/profileName']},{out_dict['hnum']},{out_dict['hden']}," 86 | + f"{out_dict['review/score']},{out_dict['review/time']},{out_dict['review/summary']},{out_dict['review/text']}\n" 87 | ) 88 | out.write(output) 89 | print("======= ALL FINISHED ========") 90 | -------------------------------------------------------------------------------- /isoforest_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import sys 5 | 6 | from sklearn.metrics import roc_curve, precision_recall_curve, auc 7 | 8 | from data_finefoods import load_graph 9 | 10 | import argparse 11 | import os 12 | 13 | import torch 14 | from utils.seed import seed_all 15 | 16 | from sklearn.ensemble import IsolationForest 17 | 18 | # %% args 19 | 20 | parser = argparse.ArgumentParser(description="IsolationForest") 21 | parser.add_argument("--name", type=str, default="wikipedia_anomaly", help="name") 22 | parser.add_argument( 23 | "--key", type=str, default="graph_anomaly_list", help="key to the data" 24 | ) 25 | parser.add_argument("--id", type=int, default=0, help="id to the data") 26 | 27 | args1 = vars(parser.parse_args()) 28 | 29 | args2 = { 30 | "seed": 0, 31 | } 32 | 33 | args = {**args1, **args2} 34 | 35 | seed_all(args["seed"]) 36 | 37 | result_dir = "results/" 38 | 39 | 40 | # %% data 41 | data = load_graph(args["name"], args["key"], args["id"]) 42 | 43 | u_ch = data.xu.shape[1] 44 | v_ch = data.xv.shape[1] 45 | e_ch = data.xe.shape[1] 46 | 47 | print( 48 | f"Data dimension: U node = {data.xu.shape}; V node = {data.xv.shape}; E edge = {data.xe.shape}; \n" 49 | ) 50 | 51 | # %% model 52 | 53 | xu, xv = data.xu, data.xv 54 | xe, adj = data.xe, data.adj 55 | yu, yv, ye = data.yu, data.yv, data.ye 56 | 57 | 58 | def train_eval(x, y): 59 | clf = IsolationForest() 60 | clf.fit(x) 61 | score = -clf.score_samples(x) 62 | 63 | rc_curve = roc_curve(y, score) 64 | pr_curve = precision_recall_curve(y, score) 65 | roc_auc = auc(rc_curve[0], rc_curve[1]) 66 | pr_auc = auc(pr_curve[1], pr_curve[0]) 67 | 68 | return roc_auc, pr_auc, rc_curve, pr_curve 69 | 70 | 71 | # %% isolation forest 72 | 73 | u_roc_auc, u_pr_auc, u_rc_curve, u_pr_curve = train_eval(xu.numpy(), yu.numpy()) 74 | v_roc_auc, v_pr_auc, v_rc_curve, v_pr_curve = train_eval(xv.numpy(), yv.numpy()) 75 | e_roc_auc, e_pr_auc, e_rc_curve, e_pr_curve = train_eval(xe.numpy(), ye.numpy()) 76 | 77 | print(args) 78 | 79 | print( 80 | f"Eval, " 81 | + f"u auc-roc: {u_roc_auc:.4f}, v auc-roc: {v_roc_auc:.4f}, e auc-roc: {e_roc_auc:.4f} | " 82 | + f"u auc-pr {u_pr_auc:.4f}, v auc-pr {v_pr_auc:.4f}, e auc-pr {e_pr_auc:.4f}" 83 | ) 84 | 85 | 86 | auc_metrics = { 87 | "u_roc_auc": u_roc_auc, 88 | "u_pr_auc": u_pr_auc, 89 | "v_roc_auc": v_roc_auc, 90 | "v_pr_auc": v_pr_auc, 91 | "e_roc_auc": e_roc_auc, 92 | "e_pr_auc": e_pr_auc, 93 | "u_roc_curve": u_rc_curve, 94 | "u_pr_curve": u_pr_curve, 95 | "v_roc_curve": v_rc_curve, 96 | "v_pr_curve": v_pr_curve, 97 | "e_roc_curve": e_rc_curve, 98 | "e_pr_curve": e_pr_curve, 99 | } 100 | 101 | output_stored = { 102 | "args": args, 103 | "auc_metrics": auc_metrics, 104 | } 105 | 106 | print("Saving current results...") 107 | torch.save( 108 | output_stored, 109 | 
os.path.join(result_dir, f"isoforest-{args['name']}-{args['id']}-output.th"), 110 | ) 111 | -------------------------------------------------------------------------------- /models/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | from typing import Optional, Tuple 5 | import torch 6 | 7 | from torch import Tensor 8 | import torch.nn as nn 9 | from torch_sparse import SparseTensor, matmul 10 | from torch_scatter import scatter 11 | 12 | from torch_geometric.nn.conv import MessagePassing 13 | from torch_geometric.nn.dense.linear import Linear 14 | 15 | from torch_geometric.typing import PairTensor, OptTensor 16 | 17 | 18 | class BEANConv(MessagePassing): 19 | def __init__( 20 | self, 21 | in_channels: Tuple[int, int, Optional[int]], 22 | out_channels: Tuple[int, int, Optional[int]], 23 | node_self_loop: bool = True, 24 | normalize: bool = True, 25 | bias: bool = True, 26 | **kwargs 27 | ): 28 | 29 | super(BEANConv, self).__init__(**kwargs) 30 | 31 | self.in_channels = in_channels 32 | self.out_channels = out_channels 33 | self.node_self_loop = node_self_loop 34 | self.normalize = normalize 35 | 36 | self.input_has_edge_channel = len(in_channels) == 3 37 | self.output_has_edge_channel = len(out_channels) == 3 38 | 39 | if self.input_has_edge_channel: 40 | if self.node_self_loop: 41 | self.in_channels_u = ( 42 | in_channels[0] + 2 * in_channels[1] + 2 * in_channels[2] 43 | ) 44 | self.in_channels_v = ( 45 | 2 * in_channels[0] + in_channels[1] + 2 * in_channels[2] 46 | ) 47 | else: 48 | self.in_channels_u = 2 * in_channels[1] + 2 * in_channels[2] 49 | self.in_channels_v = 2 * in_channels[0] + 2 * in_channels[2] 50 | self.in_channels_e = in_channels[0] + in_channels[1] + in_channels[2] 51 | else: 52 | if self.node_self_loop: 53 | self.in_channels_u = in_channels[0] + 2 * in_channels[1] 54 | self.in_channels_v = 2 * in_channels[0] + in_channels[1] 55 | else: 56 | self.in_channels_u = 2 * in_channels[1] 57 | self.in_channels_v = 2 * in_channels[0] 58 | self.in_channels_e = in_channels[0] + in_channels[1] 59 | 60 | self.lin_u = Linear(self.in_channels_u, out_channels[0], bias=bias) 61 | self.lin_v = Linear(self.in_channels_v, out_channels[1], bias=bias) 62 | if self.output_has_edge_channel: 63 | self.lin_e = Linear(self.in_channels_e, out_channels[2], bias=bias) 64 | 65 | if normalize: 66 | self.bn_u = nn.BatchNorm1d(out_channels[0]) 67 | self.bn_v = nn.BatchNorm1d(out_channels[1]) 68 | if self.output_has_edge_channel: 69 | self.bn_e = nn.BatchNorm1d(out_channels[2]) 70 | 71 | self.reset_parameters() 72 | 73 | def reset_parameters(self): 74 | self.lin_u.reset_parameters() 75 | self.lin_v.reset_parameters() 76 | if self.output_has_edge_channel: 77 | self.lin_e.reset_parameters() 78 | 79 | def forward( 80 | self, x: PairTensor, adj: SparseTensor, xe: OptTensor = None 81 | ) -> Tuple[PairTensor, Tensor]: 82 | """""" 83 | 84 | assert self.input_has_edge_channel == (xe is not None) 85 | 86 | # propagate_type: (x: PairTensor) 87 | (out_u, out_v), out_e = self.propagate(adj, x=x, xe=xe) 88 | 89 | # lin layer 90 | out_u = self.lin_u(out_u) 91 | out_v = self.lin_v(out_v) 92 | if self.output_has_edge_channel: 93 | out_e = self.lin_e(out_e) 94 | 95 | if self.normalize: 96 | out_u = self.bn_u(out_u) 97 | out_v = self.bn_v(out_v) 98 | if self.output_has_edge_channel: 99 | out_e = 
self.bn_e(out_e) 100 | 101 | return (out_u, out_v), out_e 102 | 103 | def message_and_aggregate( 104 | self, adj: SparseTensor, x: PairTensor, xe: OptTensor 105 | ) -> Tuple[PairTensor, Tensor]: 106 | 107 | xu, xv = x 108 | adj = adj.set_value(None, layout=None) 109 | 110 | # messages node to node 111 | msg_v2u_mean = matmul(adj, xv, reduce="mean") 112 | msg_v2u_sum = matmul(adj, xv, reduce="max") 113 | 114 | msg_u2v_mean = matmul(adj.t(), xu, reduce="mean") 115 | msg_u2v_sum = matmul(adj.t(), xu, reduce="max") 116 | 117 | # messages edge to node 118 | if xe is not None: 119 | msg_e2u_mean = scatter(xe, adj.storage.row(), dim=0, reduce="mean") 120 | msg_e2u_sum = scatter(xe, adj.storage.row(), dim=0, reduce="max") 121 | 122 | msg_e2v_mean = scatter(xe, adj.storage.col(), dim=0, reduce="mean") 123 | msg_e2v_sum = scatter(xe, adj.storage.col(), dim=0, reduce="max") 124 | 125 | # collect all msg (including self loop) 126 | msg_2e = None 127 | if xe is not None: 128 | if self.node_self_loop: 129 | msg_2u = torch.cat( 130 | (xu, msg_v2u_mean, msg_v2u_sum, msg_e2u_mean, msg_e2u_sum), dim=1 131 | ) 132 | msg_2v = torch.cat( 133 | (xv, msg_u2v_mean, msg_u2v_sum, msg_e2v_mean, msg_e2v_sum), dim=1 134 | ) 135 | else: 136 | msg_2u = torch.cat( 137 | (msg_v2u_mean, msg_v2u_sum, msg_e2u_mean, msg_e2u_sum), dim=1 138 | ) 139 | msg_2v = torch.cat( 140 | (msg_u2v_mean, msg_u2v_sum, msg_e2v_mean, msg_e2v_sum), dim=1 141 | ) 142 | 143 | if self.output_has_edge_channel: 144 | msg_2e = torch.cat( 145 | (xe, xu[adj.storage.row()], xv[adj.storage.col()]), dim=1 146 | ) 147 | else: 148 | if self.node_self_loop: 149 | msg_2u = torch.cat((xu, msg_v2u_mean, msg_v2u_sum), dim=1) 150 | msg_2v = torch.cat((xv, msg_u2v_mean, msg_u2v_sum), dim=1) 151 | else: 152 | msg_2u = torch.cat((msg_v2u_mean, msg_v2u_sum), dim=1) 153 | msg_2v = torch.cat((msg_u2v_mean, msg_u2v_sum), dim=1) 154 | 155 | if self.output_has_edge_channel: 156 | msg_2e = torch.cat( 157 | (xu[adj.storage.row()], xv[adj.storage.col()]), dim=1 158 | ) 159 | 160 | return (msg_2u, msg_2v), msg_2e 161 | -------------------------------------------------------------------------------- /models/conv_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | from typing import List, Optional, Tuple 5 | import torch 6 | 7 | from torch import Tensor 8 | import torch.nn as nn 9 | from torch_sparse import SparseTensor, matmul 10 | from torch_scatter import scatter 11 | 12 | from torch_geometric.nn.conv import MessagePassing 13 | from torch_geometric.nn.dense.linear import Linear 14 | 15 | from torch_geometric.typing import PairTensor, OptTensor 16 | 17 | from models.sampler import BEANAdjacency 18 | 19 | 20 | class BEANConvSample(torch.nn.Module): 21 | def __init__( 22 | self, 23 | in_channels: Tuple[int, int, Optional[int]], 24 | out_channels: Tuple[int, int, Optional[int]], 25 | node_self_loop: bool = True, 26 | normalize: bool = True, 27 | bias: bool = True, 28 | **kwargs, 29 | ): 30 | 31 | super().__init__(**kwargs) 32 | 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.node_self_loop = node_self_loop 36 | self.normalize = normalize 37 | 38 | self.input_has_edge_channel = len(in_channels) == 3 39 | self.output_has_edge_channel = len(out_channels) == 3 40 | 41 | self.v2u_conv = BEANConvNode( 42 | in_channels, 43 | out_channels[0], 44 | flow="v->u", 45 | node_self_loop=node_self_loop, 46 | normalize=normalize, 47 | bias=bias, 48 | **kwargs, 49 | ) 50 | 51 | self.u2v_conv = BEANConvNode( 52 | in_channels, 53 | out_channels[1], 54 | flow="u->v", 55 | node_self_loop=node_self_loop, 56 | normalize=normalize, 57 | bias=bias, 58 | **kwargs, 59 | ) 60 | 61 | if self.output_has_edge_channel: 62 | self.e_conv = BEANConvEdge( 63 | in_channels, 64 | out_channels[2], 65 | node_self_loop=node_self_loop, 66 | normalize=normalize, 67 | bias=bias, 68 | **kwargs, 69 | ) 70 | 71 | def forward( 72 | self, 73 | xu: PairTensor, 74 | xv: PairTensor, 75 | adj: BEANAdjacency, 76 | xe: Optional[Tuple[Tensor, Tensor, Tensor]] = None, 77 | ) -> Tuple[Tensor, Tensor, Tensor]: 78 | 79 | # source and target 80 | xus, xut = xu 81 | xvs, xvt = xv 82 | 83 | # xe 84 | if xe is not None: 85 | xe_e, xe_v2u, xe_u2v = xe 86 | 87 | out_u = self.v2u_conv((xut, xvs), adj.adj_v2u.adj, xe_v2u) 88 | out_v = self.u2v_conv((xus, xvt), adj.adj_u2v.adj, xe_u2v) 89 | 90 | out_e = None 91 | if self.output_has_edge_channel: 92 | out_e = self.e_conv((xut, xvt), adj.adj_e.adj, xe_e) 93 | 94 | return out_u, out_v, out_e 95 | 96 | 97 | class BEANConvNode(MessagePassing): 98 | def __init__( 99 | self, 100 | in_channels: Tuple[int, int, Optional[int]], 101 | out_channels: int, 102 | flow: str = "v->u", 103 | node_self_loop: bool = True, 104 | normalize: bool = True, 105 | bias: bool = True, 106 | agg: List[str] = ["mean", "max"], 107 | **kwargs, 108 | ): 109 | 110 | super().__init__(**kwargs) 111 | 112 | self.in_channels = in_channels 113 | self.out_channels = out_channels 114 | self.flow = flow 115 | self.node_self_loop = node_self_loop 116 | self.normalize = normalize 117 | self.agg = agg 118 | 119 | self.input_has_edge_channel = len(in_channels) == 3 120 | 121 | n_agg = len(agg) 122 | # calculate in channels 123 | if self.input_has_edge_channel: 124 | if self.node_self_loop: 125 | if flow == "v->u": 126 | self.in_channels_all = ( 127 | in_channels[0] + n_agg * in_channels[1] + n_agg * in_channels[2] 128 | ) 129 | else: 130 | self.in_channels_all = ( 131 | n_agg * in_channels[0] + in_channels[1] + n_agg * in_channels[2] 132 | ) 133 | else: 134 | if flow == "v->u": 135 | self.in_channels_all = ( 136 | n_agg * in_channels[1] + n_agg * in_channels[2] 137 | ) 
138 | else: 139 | self.in_channels_all = ( 140 | n_agg * in_channels[0] + n_agg * in_channels[2] 141 | ) 142 | else: 143 | if self.node_self_loop: 144 | if flow == "v->u": 145 | self.in_channels_all = in_channels[0] + n_agg * in_channels[1] 146 | else: 147 | self.in_channels_all = n_agg * in_channels[0] + in_channels[1] 148 | else: 149 | if flow == "v->u": 150 | self.in_channels_all = n_agg * in_channels[1] 151 | else: 152 | self.in_channels_all = n_agg * in_channels[0] 153 | 154 | self.lin = Linear(self.in_channels_all, out_channels, bias=bias) 155 | 156 | if normalize: 157 | self.bn = nn.BatchNorm1d(out_channels) 158 | 159 | self.reset_parameters() 160 | 161 | def reset_parameters(self): 162 | self.lin.reset_parameters() 163 | 164 | def forward(self, x: PairTensor, adj: SparseTensor, xe: OptTensor = None) -> Tensor: 165 | """""" 166 | 167 | assert self.input_has_edge_channel == (xe is not None) 168 | 169 | # propagate_type: (x: PairTensor) 170 | out = self.propagate(adj, x=x, xe=xe) 171 | 172 | # lin layer 173 | out = self.lin(out) 174 | if self.normalize: 175 | out = self.bn(out) 176 | 177 | return out 178 | 179 | def message_and_aggregate( 180 | self, adj: SparseTensor, x: PairTensor, xe: OptTensor 181 | ) -> Tensor: 182 | 183 | xu, xv = x 184 | adj = adj.set_value(None, layout=None) 185 | 186 | ## Node V to node U 187 | if self.flow == "v->u": 188 | # messages node to node 189 | msg_v2u_list = [matmul(adj, xv, reduce=ag) for ag in self.agg] 190 | 191 | # messages edge to node 192 | if xe is not None: 193 | msg_e2u_list = [ 194 | scatter(xe, adj.storage.row(), dim=0, reduce=ag) for ag in self.agg 195 | ] 196 | 197 | # collect all msg 198 | if xe is not None: 199 | if self.node_self_loop: 200 | if xu.shape[0] != msg_e2u_list[0].shape[0]: 201 | print( 202 | f"xu: {xu.shape} | msg_v2u : {msg_v2u_list[0].shape} | msg_e2u_sum : {msg_e2u_list[0].shape}" 203 | ) 204 | msg_2u = torch.cat((xu, *msg_v2u_list, *msg_e2u_list), dim=1) 205 | else: 206 | msg_2u = torch.cat((*msg_v2u_list, *msg_e2u_list), dim=1) 207 | else: 208 | if self.node_self_loop: 209 | msg_2u = torch.cat((xu, *msg_v2u_list), dim=1) 210 | else: 211 | msg_2u = torch.cat((*msg_v2u_list,), dim=1) 212 | 213 | return msg_2u 214 | 215 | ## Node U to node V 216 | else: 217 | msg_u2v_list = [matmul(adj.t(), xu, reduce=ag) for ag in self.agg] 218 | 219 | # messages edge to node 220 | if xe is not None: 221 | msg_e2v_list = [ 222 | scatter(xe, adj.storage.col(), dim=0, reduce=ag) for ag in self.agg 223 | ] 224 | 225 | # collect all msg (including self loop) 226 | if xe is not None: 227 | if self.node_self_loop: 228 | msg_2v = torch.cat((xv, *msg_u2v_list, *msg_e2v_list), dim=1) 229 | else: 230 | msg_2v = torch.cat((*msg_u2v_list, *msg_e2v_list), dim=1) 231 | else: 232 | if self.node_self_loop: 233 | msg_2v = torch.cat((xv, *msg_u2v_list), dim=1) 234 | else: 235 | msg_2v = torch.cat((*msg_u2v_list,), dim=1) 236 | 237 | return msg_2v 238 | 239 | 240 | class BEANConvEdge(MessagePassing): 241 | def __init__( 242 | self, 243 | in_channels: Tuple[int, int, Optional[int]], 244 | out_channels: int, 245 | node_self_loop: bool = True, 246 | normalize: bool = True, 247 | bias: bool = True, 248 | **kwargs, 249 | ): 250 | 251 | super().__init__(**kwargs) 252 | 253 | self.in_channels = in_channels 254 | self.out_channels = out_channels 255 | self.node_self_loop = node_self_loop 256 | self.normalize = normalize 257 | 258 | self.input_has_edge_channel = len(in_channels) == 3 259 | 260 | if self.input_has_edge_channel: 261 | self.in_channels_e = 
in_channels[0] + in_channels[1] + in_channels[2] 262 | else: 263 | self.in_channels_e = in_channels[0] + in_channels[1] 264 | 265 | self.lin_e = Linear(self.in_channels_e, out_channels, bias=bias) 266 | 267 | if normalize: 268 | self.bn_e = nn.BatchNorm1d(out_channels) 269 | 270 | self.reset_parameters() 271 | 272 | def reset_parameters(self): 273 | self.lin_e.reset_parameters() 274 | 275 | def forward(self, x: PairTensor, adj: SparseTensor, xe: Tensor) -> Tensor: 276 | """""" 277 | 278 | # propagate_type: (x: PairTensor) 279 | out_e = self.propagate(adj, x=x, xe=xe) 280 | 281 | # lin layer 282 | out_e = self.lin_e(out_e) 283 | 284 | if self.normalize: 285 | out_e = self.bn_e(out_e) 286 | 287 | return out_e 288 | 289 | def message_and_aggregate( 290 | self, adj: SparseTensor, x: PairTensor, xe: OptTensor 291 | ) -> Tensor: 292 | 293 | xu, xv = x 294 | adj = adj.set_value(None, layout=None) 295 | 296 | # collect all msg (including self loop) 297 | if xe is not None: 298 | msg_2e = torch.cat( 299 | (xe, xu[adj.storage.row()], xv[adj.storage.col()]), dim=1 300 | ) 301 | else: 302 | msg_2e = torch.cat((xu[adj.storage.row()], xv[adj.storage.col()]), dim=1) 303 | 304 | return msg_2e 305 | -------------------------------------------------------------------------------- /models/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch_geometric.data import Data 6 | from torch_geometric.typing import OptTensor 7 | from torch_sparse.tensor import SparseTensor 8 | 9 | 10 | class BipartiteData(Data): 11 | def __init__( 12 | self, 13 | adj: SparseTensor, 14 | xu: OptTensor = None, 15 | xv: OptTensor = None, 16 | xe: OptTensor = None, 17 | **kwargs 18 | ): 19 | super().__init__() 20 | self.adj = adj 21 | self.xu = xu 22 | self.xv = xv 23 | self.xe = xe 24 | 25 | for key, value in kwargs.items(): 26 | setattr(self, key, value) 27 | 28 | def __inc__(self, key, value, *args, **kwargs): 29 | if key == "adj": 30 | return torch.tensor([[self.xu.size(0)], [self.xv.size(0)]]) 31 | else: 32 | return super().__inc__(key, value, *args, **kwargs) 33 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | from typing import Dict, Tuple 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from torch_sparse import SparseTensor 9 | 10 | 11 | def reconstruction_loss( 12 | xu: Tensor, 13 | xv: Tensor, 14 | xe: Tensor, 15 | adj: SparseTensor, 16 | edge_pred_samples: SparseTensor, 17 | out: Dict[str, Tensor], 18 | xe_loss_weight: float = 1.0, 19 | structure_loss_weight: float = 1.0, 20 | ) -> Tuple[Tensor, Dict[str, Tensor]]: 21 | # feature mse 22 | xu_loss = F.mse_loss(xu, out["xu"]) 23 | xv_loss = F.mse_loss(xv, out["xv"]) 24 | xe_loss = F.mse_loss(xe, out["xe"]) 25 | feature_loss = xu_loss + xv_loss + xe_loss_weight * xe_loss 26 | 27 | # structure loss 28 | edge_gt = (edge_pred_samples.storage.value() > 0).float() 29 | structure_loss = F.binary_cross_entropy(out["eprob"], edge_gt) 30 | 31 | loss = feature_loss + structure_loss_weight * structure_loss 32 | 33 | loss_component = { 34 | "xu": xu_loss, 35 | "xv": xv_loss, 36 | "xe": xe_loss, 37 | "e": structure_loss, 38 | "total": loss, 39 | } 40 | 41 | return loss, loss_component 42 | -------------------------------------------------------------------------------- /models/net.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | from torch_sparse import SparseTensor 10 | from torch_geometric.nn.dense.linear import Linear 11 | 12 | from typing import Tuple, Union, Dict 13 | 14 | from models.conv import BEANConv 15 | 16 | 17 | def make_tuple(x: Union[int, Tuple[int, int, int], Tuple[int, int]], repeat: int = 3): 18 | if isinstance(x, int): 19 | if repeat == 2: 20 | return (x, x) 21 | else: 22 | return (x, x, x) 23 | else: 24 | return x 25 | 26 | 27 | def apply_relu_dropout(x: Tensor, dropout_prob: float, training: bool) -> Tensor: 28 | x = F.relu(x) 29 | if dropout_prob > 0.0: 30 | x = F.dropout(x, p=dropout_prob, training=training) 31 | return x 32 | 33 | 34 | class GraphBEAN(nn.Module): 35 | def __init__( 36 | self, 37 | in_channels: Union[int, Tuple[int, int, int]], 38 | hidden_channels: Union[int, Tuple[int, int, int]] = 32, 39 | latent_channels: Union[int, Tuple[int, int]] = 64, 40 | edge_pred_latent: int = 64, 41 | n_layers_encoder: int = 4, 42 | n_layers_decoder: int = 4, 43 | n_layers_mlp: int = 4, 44 | dropout_prob: float = 0.0, 45 | ): 46 | 47 | super().__init__() 48 | 49 | self.in_channels = make_tuple(in_channels) 50 | self.hidden_channels = make_tuple(hidden_channels) 51 | self.latent_channels = make_tuple(latent_channels, 2) 52 | self.edge_pred_latent = edge_pred_latent 53 | self.n_layers_encoder = n_layers_encoder 54 | self.n_layers_decoder = n_layers_decoder 55 | self.n_layers_mlp = n_layers_mlp 56 | self.dropout_prob = dropout_prob 57 | 58 | self.create_encoder() 59 | self.create_feature_decoder() 60 | self.create_structure_decoder() 61 | 62 | def create_encoder(self): 63 | self.encoder_convs = nn.ModuleList() 64 | for i in range(self.n_layers_encoder): 65 | if i == 0: 66 | in_channels = self.in_channels 67 | out_channels = self.hidden_channels 68 | elif i == self.n_layers_encoder - 1: 69 | in_channels = self.hidden_channels 70 | out_channels = self.latent_channels 71 | else: 72 
| in_channels = self.hidden_channels 73 | out_channels = self.hidden_channels 74 | 75 | if i == self.n_layers_encoder - 1: 76 | self.encoder_convs.append( 77 | BEANConv(in_channels, out_channels, node_self_loop=False) 78 | ) 79 | else: 80 | self.encoder_convs.append( 81 | BEANConv(in_channels, out_channels, node_self_loop=True) 82 | ) 83 | 84 | def create_feature_decoder(self): 85 | self.decoder_convs = nn.ModuleList() 86 | for i in range(self.n_layers_decoder): 87 | if i == 0: 88 | in_channels = self.latent_channels 89 | out_channels = self.hidden_channels 90 | elif i == self.n_layers_decoder - 1: 91 | in_channels = self.hidden_channels 92 | out_channels = self.in_channels 93 | else: 94 | in_channels = self.hidden_channels 95 | out_channels = self.hidden_channels 96 | 97 | self.decoder_convs.append(BEANConv(in_channels, out_channels)) 98 | 99 | def create_structure_decoder(self): 100 | self.u_mlp_layers = nn.ModuleList() 101 | self.v_mlp_layers = nn.ModuleList() 102 | 103 | for i in range(self.n_layers_mlp): 104 | if i == 0: 105 | in_channels = self.latent_channels 106 | else: 107 | in_channels = (self.edge_pred_latent, self.edge_pred_latent) 108 | out_channels = self.edge_pred_latent 109 | 110 | self.u_mlp_layers.append(Linear(in_channels[0], out_channels)) 111 | 112 | self.v_mlp_layers.append(Linear(in_channels[1], out_channels)) 113 | 114 | def forward( 115 | self, 116 | xu: Tensor, 117 | xv: Tensor, 118 | xe: Tensor, 119 | adj: SparseTensor, 120 | edge_pred_samples: SparseTensor, 121 | ) -> Dict[str, Tensor]: 122 | 123 | # encoder 124 | for i, conv in enumerate(self.encoder_convs): 125 | (xu, xv), xe = conv((xu, xv), adj, xe=xe) 126 | if i != self.n_layers_encoder - 1: 127 | xu = apply_relu_dropout(xu, self.dropout_prob, self.training) 128 | xv = apply_relu_dropout(xv, self.dropout_prob, self.training) 129 | xe = apply_relu_dropout(xe, self.dropout_prob, self.training) 130 | 131 | # get latent vars 132 | zu, zv = xu, xv 133 | 134 | # feature decoder 135 | for i, conv in enumerate(self.decoder_convs): 136 | (xu, xv), xe = conv((xu, xv), adj, xe=xe) 137 | if i != self.n_layers_decoder - 1: 138 | xu = apply_relu_dropout(xu, self.dropout_prob, self.training) 139 | xv = apply_relu_dropout(xv, self.dropout_prob, self.training) 140 | xe = apply_relu_dropout(xe, self.dropout_prob, self.training) 141 | 142 | # structure decoder 143 | zu2, zv2 = zu, zv 144 | for i, layer in enumerate(self.u_mlp_layers): 145 | zu2 = layer(zu2) 146 | if i != self.n_layers_mlp - 1: 147 | zu2 = apply_relu_dropout(zu2, self.dropout_prob, self.training) 148 | 149 | for i, layer in enumerate(self.v_mlp_layers): 150 | zv2 = layer(zv2) 151 | if i != self.n_layers_mlp - 1: 152 | zv2 = apply_relu_dropout(zv2, self.dropout_prob, self.training) 153 | 154 | zu2_edge = zu2[edge_pred_samples.storage.row()] 155 | zv2_edge = zv2[edge_pred_samples.storage.col()] 156 | 157 | eprob = torch.sigmoid(torch.sum(zu2_edge * zv2_edge, dim=1)) 158 | 159 | # collect results 160 | result = {"xu": xu, "xv": xv, "xe": xe, "zu": zu, "zv": zv, "eprob": eprob} 161 | 162 | return result 163 | -------------------------------------------------------------------------------- /models/net_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | from torch_sparse import SparseTensor 10 | from torch_geometric.nn.dense.linear import Linear 11 | 12 | from typing import List, Tuple, Union, Dict 13 | 14 | from models.conv_sample import BEANConvSample 15 | from models.sampler import BEANAdjacency, BipartiteNeighborSampler, EdgeLoader 16 | from utils.sparse_combine import xe_split3 17 | 18 | from tqdm import tqdm 19 | 20 | 21 | def make_tuple(x: Union[int, Tuple[int, int, int], Tuple[int, int]], repeat: int = 3): 22 | if isinstance(x, int): 23 | if repeat == 2: 24 | return (x, x) 25 | else: 26 | return (x, x, x) 27 | else: 28 | return x 29 | 30 | 31 | def apply_relu_dropout(x: Tensor, dropout_prob: float, training: bool) -> Tensor: 32 | x = F.relu(x) 33 | if dropout_prob > 0.0: 34 | x = F.dropout(x, p=dropout_prob, training=training) 35 | return x 36 | 37 | 38 | class GraphBEANSampled(nn.Module): 39 | def __init__( 40 | self, 41 | in_channels: Union[int, Tuple[int, int, int]], 42 | hidden_channels: Union[int, Tuple[int, int, int]] = 32, 43 | latent_channels: Union[int, Tuple[int, int]] = 64, 44 | edge_pred_latent: int = 64, 45 | n_layers_encoder: int = 4, 46 | n_layers_decoder: int = 4, 47 | n_layers_mlp: int = 4, 48 | dropout_prob: float = 0.0, 49 | ): 50 | 51 | super().__init__() 52 | 53 | self.in_channels = make_tuple(in_channels) 54 | self.hidden_channels = make_tuple(hidden_channels) 55 | self.latent_channels = make_tuple(latent_channels, 2) 56 | self.edge_pred_latent = edge_pred_latent 57 | self.n_layers_encoder = n_layers_encoder 58 | self.n_layers_decoder = n_layers_decoder 59 | self.n_layers_mlp = n_layers_mlp 60 | self.dropout_prob = dropout_prob 61 | 62 | self.create_encoder() 63 | self.create_feature_decoder() 64 | self.create_structure_decoder() 65 | 66 | def create_encoder(self): 67 | self.encoder_convs = nn.ModuleList() 68 | for i in range(self.n_layers_encoder): 69 | if i == 0: 70 | in_channels = self.in_channels 71 | out_channels = self.hidden_channels 72 | elif i == self.n_layers_encoder - 1: 73 | in_channels = self.hidden_channels 74 | out_channels = self.latent_channels 75 | else: 76 | in_channels = self.hidden_channels 77 | out_channels = self.hidden_channels 78 | 79 | if i == self.n_layers_encoder - 1: 80 | self.encoder_convs.append( 81 | BEANConvSample(in_channels, out_channels, node_self_loop=False) 82 | ) 83 | else: 84 | self.encoder_convs.append( 85 | BEANConvSample(in_channels, out_channels, node_self_loop=True) 86 | ) 87 | 88 | def create_feature_decoder(self): 89 | self.decoder_convs = nn.ModuleList() 90 | for i in range(self.n_layers_decoder): 91 | if i == 0: 92 | in_channels = self.latent_channels 93 | out_channels = self.hidden_channels 94 | elif i == self.n_layers_decoder - 1: 95 | in_channels = self.hidden_channels 96 | out_channels = self.in_channels 97 | else: 98 | in_channels = self.hidden_channels 99 | out_channels = self.hidden_channels 100 | 101 | self.decoder_convs.append(BEANConvSample(in_channels, out_channels)) 102 | 103 | def create_structure_decoder(self): 104 | self.u_mlp_layers = nn.ModuleList() 105 | self.v_mlp_layers = nn.ModuleList() 106 | 107 | for i in range(self.n_layers_mlp): 108 | if i == 0: 109 | in_channels = self.latent_channels 110 | else: 111 | in_channels = (self.edge_pred_latent, self.edge_pred_latent) 112 | out_channels = self.edge_pred_latent 113 | 114 | 
self.u_mlp_layers.append(Linear(in_channels[0], out_channels)) 115 | 116 | self.v_mlp_layers.append(Linear(in_channels[1], out_channels)) 117 | 118 | def forward( 119 | self, 120 | xu: Tensor, 121 | xv: Tensor, 122 | xe: Tensor, 123 | bean_adjs: List[BEANAdjacency], 124 | e_flags: List[SparseTensor], 125 | edge_pred_samples: SparseTensor, 126 | ) -> Dict[str, Tensor]: 127 | 128 | assert self.n_layers_encoder + self.n_layers_decoder == len(bean_adjs) 129 | 130 | # encoder 131 | for i, conv in enumerate(self.encoder_convs): 132 | badj = bean_adjs[i] 133 | e_flag = e_flags[i] 134 | 135 | # target size 136 | n_ut = badj.adj_v2u.size[0] 137 | n_vt = badj.adj_u2v.size[1] 138 | 139 | # get xut and xvt 140 | xus, xut = xu, xu[:n_ut] # target nodes are always placed first 141 | xvs, xvt = xv, xv[:n_vt] 142 | 143 | # get xe 144 | xe_e, xe_v2u, xe_u2v = xe_split3(xe, e_flag.storage.value()) 145 | 146 | # do convolution 147 | xu, xv, xe = conv( 148 | xu=(xus, xut), xv=(xvs, xvt), adj=badj, xe=(xe_e, xe_v2u, xe_u2v) 149 | ) 150 | 151 | if i != self.n_layers_encoder - 1: 152 | xu = apply_relu_dropout(xu, self.dropout_prob, self.training) 153 | xv = apply_relu_dropout(xv, self.dropout_prob, self.training) 154 | xe = apply_relu_dropout(xe, self.dropout_prob, self.training) 155 | 156 | # extract latent vars (only target nodes) 157 | last_badj = bean_adjs[-1] 158 | n_u_target = last_badj.adj_v2u.size[0] 159 | n_v_target = last_badj.adj_u2v.size[1] 160 | # get latent vars 161 | zu, zv = xu[:n_u_target], xv[:n_v_target] 162 | 163 | # feature decoder 164 | for i, conv in enumerate(self.decoder_convs): 165 | 166 | badj = bean_adjs[self.n_layers_encoder + i] 167 | e_flag = e_flags[self.n_layers_encoder + i] 168 | 169 | # target size 170 | n_ut = badj.adj_v2u.size[0] 171 | n_vt = badj.adj_u2v.size[1] 172 | 173 | # get xut and xvt 174 | xus, xut = xu, xu[:n_ut] # target nodes are always placed first 175 | xvs, xvt = xv, xv[:n_vt] 176 | 177 | # get xe 178 | if xe is not None: 179 | xe_e, xe_v2u, xe_u2v = xe_split3(xe, e_flag.storage.value()) 180 | else: 181 | xe_e, xe_v2u, xe_u2v = None, None, None 182 | 183 | # do convolution 184 | xu, xv, xe = conv( 185 | xu=(xus, xut), xv=(xvs, xvt), adj=badj, xe=(xe_e, xe_v2u, xe_u2v) 186 | ) 187 | 188 | if i != self.n_layers_decoder - 1: 189 | xu = apply_relu_dropout(xu, self.dropout_prob, self.training) 190 | xv = apply_relu_dropout(xv, self.dropout_prob, self.training) 191 | xe = apply_relu_dropout(xe, self.dropout_prob, self.training) 192 | 193 | # structure decoder 194 | zu2, zv2 = zu, zv 195 | for i, layer in enumerate(self.u_mlp_layers): 196 | zu2 = layer(zu2) 197 | if i != self.n_layers_mlp - 1: 198 | zu2 = apply_relu_dropout(zu2, self.dropout_prob, self.training) 199 | 200 | for i, layer in enumerate(self.v_mlp_layers): 201 | zv2 = layer(zv2) 202 | if i != self.n_layers_mlp - 1: 203 | zv2 = apply_relu_dropout(zv2, self.dropout_prob, self.training) 204 | 205 | zu2_edge = zu2[edge_pred_samples.storage.row()] 206 | zv2_edge = zv2[edge_pred_samples.storage.col()] 207 | 208 | eprob = torch.sigmoid(torch.sum(zu2_edge * zv2_edge, dim=1)) 209 | 210 | # collect results 211 | result = {"xu": xu, "xv": xv, "xe": xe, "zu": zu, "zv": zv, "eprob": eprob} 212 | 213 | return result 214 | 215 | def apply_conv(self, conv, dir_adj, xu_all, xv_all, xe_all, device): 216 | xu = xu_all[dir_adj.u_id].to(device) 217 | xv = xv_all[dir_adj.v_id].to(device) 218 | xe = xe_all[dir_adj.e_id].to(device) if xe_all is not None else None 219 | adj = dir_adj.adj.to(device) 220 | 221 | out = conv((xu, 
xv), adj, xe) 222 | 223 | return out 224 | 225 | def inference( 226 | self, 227 | xu_all: Tensor, 228 | xv_all: Tensor, 229 | xe_all: Tensor, 230 | adj_all: SparseTensor, 231 | edge_pred_samples: SparseTensor, 232 | batch_sizes: Tuple[int, int, int], 233 | device, 234 | progress_bar: bool = True, 235 | **kwargs, 236 | ) -> Dict[str, Tensor]: 237 | 238 | kwargs["shuffle"] = False 239 | u_loader = BipartiteNeighborSampler( 240 | adj_all, 241 | n_layers=1, 242 | base="u", 243 | batch_size=batch_sizes[0], 244 | n_other_node=1, 245 | num_neighbors_u=-1, 246 | num_neighbors_v=1, 247 | **kwargs, 248 | ) 249 | v_loader = BipartiteNeighborSampler( 250 | adj_all, 251 | n_layers=1, 252 | base="v", 253 | batch_size=batch_sizes[1], 254 | n_other_node=1, 255 | num_neighbors_u=1, 256 | num_neighbors_v=-1, 257 | **kwargs, 258 | ) 259 | e_loader = EdgeLoader(adj_all, batch_size=batch_sizes[2], **kwargs) 260 | 261 | u_mlp_loader = torch.utils.data.DataLoader( 262 | torch.arange(xu_all.shape[0]), batch_size=batch_sizes[0], **kwargs 263 | ) 264 | v_mlp_loader = torch.utils.data.DataLoader( 265 | torch.arange(xv_all.shape[0]), batch_size=batch_sizes[1], **kwargs 266 | ) 267 | 268 | epred_loader = torch.utils.data.DataLoader( 269 | torch.arange(edge_pred_samples.nnz()), batch_size=batch_sizes[2], **kwargs 270 | ) 271 | 272 | total_iter = ( 273 | (len(u_loader) + len(v_loader)) 274 | * (self.n_layers_encoder + self.n_layers_decoder) 275 | + len(e_loader) * (self.n_layers_encoder + self.n_layers_decoder - 1) 276 | + (len(u_mlp_loader) + len(v_mlp_loader)) * self.n_layers_mlp 277 | + len(epred_loader) 278 | ) 279 | if progress_bar: 280 | pbar = tqdm(total=total_iter, leave=False) 281 | pbar.set_description(f"Evaluation") 282 | 283 | # encoder 284 | for i, conv in enumerate(self.encoder_convs): 285 | 286 | ## next u nodes 287 | xu_list = [] 288 | for _, _, adjacency, _ in u_loader: 289 | out = self.apply_conv( 290 | conv.v2u_conv, adjacency.adj_v2u, xu_all, xv_all, xe_all, device 291 | ) 292 | if i != self.n_layers_encoder - 1: 293 | out = F.relu(out) 294 | xu_list.append(out.cpu()) 295 | if progress_bar: 296 | pbar.update(1) 297 | xu_all_next = torch.cat(xu_list, dim=0) 298 | 299 | ## next v nodes 300 | xv_list = [] 301 | for _, _, adjacency, _ in v_loader: 302 | out = self.apply_conv( 303 | conv.u2v_conv, adjacency.adj_u2v, xu_all, xv_all, xe_all, device 304 | ) 305 | if i != self.n_layers_encoder - 1: 306 | out = F.relu(out) 307 | xv_list.append(out.cpu()) 308 | if progress_bar: 309 | pbar.update(1) 310 | xv_all_next = torch.cat(xv_list, dim=0) 311 | 312 | ## next edge 313 | if i != self.n_layers_encoder - 1: 314 | xe_list = [] 315 | for adj_e in e_loader: 316 | out = self.apply_conv( 317 | conv.e_conv, adj_e, xu_all, xv_all, xe_all, device 318 | ) 319 | out = F.relu(out) 320 | xe_list.append(out.cpu()) 321 | if progress_bar: 322 | pbar.update(1) 323 | xe_all_next = torch.cat(xe_list, dim=0) 324 | else: 325 | xe_all_next = None 326 | 327 | xu_all = xu_all_next 328 | xv_all = xv_all_next 329 | xe_all = xe_all_next 330 | 331 | # get latent vars 332 | zu_all, zv_all = xu_all, xv_all 333 | 334 | # feature decoder 335 | for i, conv in enumerate(self.decoder_convs): 336 | 337 | ## next u nodes 338 | xu_list = [] 339 | for _, _, adjacency, _ in u_loader: 340 | out = self.apply_conv( 341 | conv.v2u_conv, adjacency.adj_v2u, xu_all, xv_all, xe_all, device 342 | ) 343 | if i != self.n_layers_decoder - 1: 344 | out = F.relu(out) 345 | xu_list.append(out.cpu()) 346 | if progress_bar: 347 | pbar.update(1) 348 | xu_all_next = 
torch.cat(xu_list, dim=0) 349 | 350 | ## next v nodes 351 | xv_list = [] 352 | for _, _, adjacency, _ in v_loader: 353 | out = self.apply_conv( 354 | conv.u2v_conv, adjacency.adj_u2v, xu_all, xv_all, xe_all, device 355 | ) 356 | if i != self.n_layers_decoder - 1: 357 | out = F.relu(out) 358 | xv_list.append(out.cpu()) 359 | if progress_bar: 360 | pbar.update(1) 361 | xv_all_next = torch.cat(xv_list, dim=0) 362 | 363 | ## next edge 364 | xe_list = [] 365 | for adj_e in e_loader: 366 | out = self.apply_conv( 367 | conv.e_conv, adj_e, xu_all, xv_all, xe_all, device 368 | ) 369 | if i != self.n_layers_decoder - 1: 370 | out = F.relu(out) 371 | xe_list.append(out.cpu()) 372 | if progress_bar: 373 | pbar.update(1) 374 | xe_all_next = torch.cat(xe_list, dim=0) 375 | 376 | xu_all = xu_all_next 377 | xv_all = xv_all_next 378 | xe_all = xe_all_next 379 | 380 | # structure decoder 381 | zu2_all, zv2_all = zu_all, zv_all 382 | for i, layer in enumerate(self.u_mlp_layers): 383 | zu2_list = [] 384 | for batch in u_mlp_loader: 385 | out = layer(zu2_all[batch].to(device)) 386 | if i != self.n_layers_mlp - 1: 387 | out = F.relu(out) 388 | zu2_list.append(out.cpu()) 389 | if progress_bar: 390 | pbar.update(1) 391 | zu2_all = torch.cat(zu2_list, dim=0) 392 | 393 | for i, layer in enumerate(self.v_mlp_layers): 394 | zv2_list = [] 395 | for batch in v_mlp_loader: 396 | out = layer(zv2_all[batch].to(device)) 397 | if i != self.n_layers_mlp - 1: 398 | out = F.relu(out) 399 | zv2_list.append(out.cpu()) 400 | if progress_bar: 401 | pbar.update(1) 402 | zv2_all = torch.cat(zv2_list, dim=0) 403 | 404 | eprob_list = [] 405 | for batch in epred_loader: 406 | zu2_edge = zu2_all[edge_pred_samples.storage.row()[batch]].to(device) 407 | zv2_edge = zv2_all[edge_pred_samples.storage.col()[batch]].to(device) 408 | out = torch.sigmoid(torch.sum(zu2_edge * zv2_edge, dim=1)) 409 | eprob_list.append(out.cpu()) 410 | if progress_bar: 411 | pbar.update(1) 412 | eprob_all = torch.cat(eprob_list, dim=0) 413 | 414 | # collect results 415 | result = { 416 | "xu": xu_all, 417 | "xv": xv_all, 418 | "xe": xe_all, 419 | "zu": zu_all, 420 | "zv": zv_all, 421 | "eprob": eprob_all, 422 | } 423 | 424 | if progress_bar: 425 | pbar.close() 426 | 427 | return result 428 | -------------------------------------------------------------------------------- /models/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | import math 8 | 9 | from torch_sparse import SparseTensor 10 | from torch_sparse.storage import SparseStorage 11 | 12 | from typing import List, NamedTuple, Optional, Tuple, Union 13 | 14 | from utils.sparse_combine import spadd 15 | from utils.sprand import sprand 16 | 17 | 18 | class EdgePredictionSampler: 19 | def __init__( 20 | self, 21 | adj: SparseTensor, 22 | n_random: Optional[int] = None, 23 | mult: Optional[float] = 2.0, 24 | ): 25 | self.adj = adj 26 | 27 | if n_random is None: 28 | n_pos = adj.nnz() 29 | n_random = mult * n_pos 30 | 31 | self.adj = adj 32 | self.n_random = n_random 33 | 34 | def sample(self): 35 | rnd_samples = sprand(self.adj.sparse_sizes(), self.n_random) 36 | rnd_samples.fill_value_(-1) 37 | rnd_samples = rnd_samples.to(self.adj.device()) 38 | 39 | pos_samples = self.adj.fill_value(2) 40 | 41 | samples = spadd(rnd_samples, pos_samples) 42 | samples.set_value_( 43 | torch.minimum( 44 | samples.storage.value(), torch.ones_like(samples.storage.value()) 45 | ), 46 | layout="coo", 47 | ) 48 | 49 | return samples 50 | 51 | 52 | ### REGION Neighbor Sampling 53 | def sample_v_given_u( 54 | adj: SparseTensor, 55 | u_indices: Tensor, 56 | prev_v: Tensor, 57 | num_neighbors: int, 58 | replace=False, 59 | ) -> Tuple[SparseTensor, Tensor]: 60 | 61 | # to homogenous adjacency 62 | nu, nv = adj.sparse_sizes() 63 | adj_h = SparseTensor( 64 | row=adj.storage.row(), 65 | col=adj.storage.col() + nu, 66 | value=adj.storage.value(), 67 | sparse_sizes=(nu + nv, nu + nv), 68 | ) 69 | 70 | res_adj_h, res_id = adj_h.sample_adj( 71 | torch.cat([u_indices, prev_v + nu]), 72 | num_neighbors=num_neighbors, 73 | replace=replace, 74 | ) 75 | 76 | ni = len(u_indices) 77 | v_indices = res_id[ni:] - nu 78 | res_adj = res_adj_h[:ni, ni:] 79 | 80 | return res_adj, v_indices 81 | 82 | 83 | def sample_u_given_v( 84 | adj: SparseTensor, 85 | v_indices: Tensor, 86 | prev_u: Tensor, 87 | num_neighbors: int, 88 | replace=False, 89 | ) -> Tuple[SparseTensor, Tensor]: 90 | 91 | # to homogenous adjacency 92 | res_adj_t, u_indices = sample_v_given_u( 93 | adj.t(), v_indices, prev_u, num_neighbors=num_neighbors, replace=replace 94 | ) 95 | 96 | return res_adj_t.t(), u_indices 97 | 98 | 99 | class DirectedAdj(NamedTuple): 100 | adj: SparseTensor 101 | u_id: Tensor 102 | v_id: Tensor 103 | e_id: Optional[Tensor] 104 | size: Tuple[int, int] 105 | flow: str 106 | 107 | def to(self, *args, **kwargs): 108 | adj = self.adj.to(*args, **kwargs) 109 | u_id = self.u_id.to(*args, **kwargs) 110 | v_id = self.v_id.to(*args, **kwargs) 111 | e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None 112 | return DirectedAdj(adj, u_id, v_id, e_id, self.size, self.flow) 113 | 114 | 115 | class BEANAdjacency(NamedTuple): 116 | adj_v2u: DirectedAdj 117 | adj_u2v: DirectedAdj 118 | adj_e: Optional[DirectedAdj] 119 | 120 | def to(self, *args, **kwargs): 121 | adj_v2u = self.adj_v2u.to(*args, **kwargs) 122 | adj_u2v = self.adj_u2v.to(*args, **kwargs) 123 | adj_e = None 124 | if self.adj_e is not None: 125 | adj_e = self.adj_e.to(*args, **kwargs) 126 | return BEANAdjacency(adj_v2u, adj_u2v, adj_e) 127 | 128 | 129 | class BipartiteNeighborSampler(torch.utils.data.DataLoader): 130 | def __init__( 131 | self, 132 | adj: SparseTensor, 133 | n_layers: int, 134 | num_neighbors_u: Union[int, List[int]], 135 | num_neighbors_v: Union[int, List[int]], 136 | base: 
str = "u", 137 | n_other_node: int = -1, 138 | **kwargs 139 | ): 140 | 141 | adj = adj.to("cpu") 142 | 143 | if "collate_fn" in kwargs: 144 | del kwargs["collate_fn"] 145 | 146 | self.adj = adj 147 | self.n_layers = n_layers 148 | self.base = base 149 | self.n_other_node = n_other_node 150 | 151 | if isinstance(num_neighbors_u, int): 152 | num_neighbors_u = [num_neighbors_u for _ in range(n_layers)] 153 | if isinstance(num_neighbors_v, int): 154 | num_neighbors_v = [num_neighbors_v for _ in range(n_layers)] 155 | self.num_neighbors_u = num_neighbors_u 156 | self.num_neighbors_v = num_neighbors_v 157 | 158 | if base == "u": # start from u 159 | item_idx = torch.arange(adj.sparse_size(0)) 160 | elif base == "v": # start from v instead 161 | item_idx = torch.arange(adj.sparse_size(1)) 162 | elif base == "e": # start from e instead 163 | item_idx = torch.arange(adj.nnz()) 164 | else: # start from u default 165 | item_idx = torch.arange(adj.sparse_size(0)) 166 | 167 | value = torch.arange(adj.nnz()) 168 | adj = adj.set_value(value, layout="coo") 169 | self.__val__ = adj.storage.value() 170 | 171 | # transpose of adjacency 172 | self.adj = adj 173 | self.adj_t = adj.t() 174 | 175 | # homogenous graph adjacency matrix 176 | self.nu, self.nv = self.adj.sparse_sizes() 177 | self.adj_homogen = SparseTensor( 178 | row=self.adj.storage.row(), 179 | col=self.adj.storage.col() + self.nu, 180 | value=self.adj.storage.value(), 181 | sparse_sizes=(self.nu + self.nv, self.nu + self.nv), 182 | ) 183 | self.adj_t_homogen = SparseTensor( 184 | row=self.adj_t.storage.row(), 185 | col=self.adj_t.storage.col() + self.nv, 186 | value=self.adj_t.storage.value(), 187 | sparse_sizes=(self.nu + self.nv, self.nu + self.nv), 188 | ) 189 | 190 | super(BipartiteNeighborSampler, self).__init__( 191 | item_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs 192 | ) 193 | 194 | def sample_v_given_u( 195 | self, u_indices: Tensor, prev_v: Tensor, num_neighbors: int 196 | ) -> Tuple[SparseTensor, Tensor]: 197 | 198 | res_adj_h, res_id = self.adj_homogen.sample_adj( 199 | torch.cat([u_indices, prev_v + self.nu]), 200 | num_neighbors=num_neighbors, 201 | replace=False, 202 | ) 203 | 204 | ni = len(u_indices) 205 | v_indices = res_id[ni:] - self.nu 206 | res_adj = res_adj_h[:ni, ni:] 207 | 208 | return res_adj, v_indices 209 | 210 | def sample_u_given_v( 211 | self, v_indices: Tensor, prev_u: Tensor, num_neighbors: int 212 | ) -> Tuple[SparseTensor, Tensor]: 213 | 214 | # start = time.time() 215 | res_adj_h, res_id = self.adj_t_homogen.sample_adj( 216 | torch.cat([v_indices, prev_u + self.nv]), 217 | num_neighbors=num_neighbors, 218 | replace=False, 219 | ) 220 | # print(f"adjoint sampling : {time.time() - start} s") 221 | 222 | ni = len(v_indices) 223 | u_indices = res_id[ni:] - self.nv 224 | res_adj = res_adj_h[:ni, ni:] 225 | 226 | return res_adj.t(), u_indices 227 | 228 | def adjacency_from_samples( 229 | self, adj: SparseTensor, u_id: Tensor, v_id: Tensor, flow: str 230 | ) -> DirectedAdj: 231 | 232 | e_id = adj.storage.value() 233 | size = adj.sparse_sizes() 234 | if self.__val__ is not None: 235 | adj.set_value_(self.__val__[e_id], layout="coo") 236 | 237 | return DirectedAdj(adj, u_id, v_id, e_id, size, flow) 238 | 239 | def combine_adjacency( 240 | self, v2u_adj: SparseTensor, u2v_adj: SparseTensor, e_adj: SparseTensor 241 | ) -> SparseTensor: 242 | 243 | # start = time.time() 244 | nu = u2v_adj.sparse_size(0) 245 | nv = v2u_adj.sparse_size(1) 246 | 247 | row = torch.cat( 248 | [e_adj.storage.row(), 
v2u_adj.storage.row(), u2v_adj.storage.row()], dim=-1 249 | ) 250 | col = torch.cat( 251 | [e_adj.storage.col(), v2u_adj.storage.col(), u2v_adj.storage.col()], dim=-1 252 | ) 253 | value = torch.cat( 254 | [e_adj.storage.value(), v2u_adj.storage.value(), u2v_adj.storage.value()], 255 | dim=0, 256 | ) 257 | fl = torch.cat( 258 | [ 259 | torch.ones(e_adj.nnz()), 260 | 2 * torch.ones(v2u_adj.nnz()), 261 | 4 * torch.ones(u2v_adj.nnz()), 262 | ] 263 | ) 264 | 265 | storage = SparseStorage( 266 | row=row, col=col, value=value, sparse_sizes=(nu, nv), is_sorted=False 267 | ) 268 | storage = storage.coalesce(reduce="mean") 269 | 270 | fl_storage = SparseStorage( 271 | row=row, col=col, value=fl, sparse_sizes=(nu, nv), is_sorted=False 272 | ) 273 | fl_storage = fl_storage.coalesce(reduce="sum") 274 | 275 | res = SparseTensor.from_storage(storage) 276 | flag = SparseTensor.from_storage(fl_storage) 277 | 278 | # print(f"combine adj : {time.time() - start} s") 279 | 280 | return res, flag 281 | 282 | def sample(self, batch): 283 | 284 | # start = time.time() 285 | 286 | if not isinstance(batch, Tensor): 287 | batch = torch.tensor(batch) 288 | 289 | batch_size: int = len(batch) 290 | 291 | # calculate batch_size for another node 292 | if self.n_other_node == -1 and self.base in ["u", "v"]: 293 | # do proportional 294 | nu, nv = self.adj.sparse_sizes() 295 | if self.base == "u": 296 | self.n_other_node = int(math.ceil((nv / nu) * batch_size)) 297 | elif self.base == "v": 298 | self.n_other_node = int(math.ceil((nu / nv) * batch_size)) 299 | 300 | ## get the other indices 301 | empty_list = torch.tensor([], dtype=torch.long) 302 | if self.base == "u": 303 | # get the base node for v 304 | u_indices = batch 305 | res_adj, res_id = self.sample_v_given_u( 306 | u_indices, empty_list, num_neighbors=self.num_neighbors_u[0] 307 | ) 308 | rand_id = torch.randperm(len(res_id))[: self.n_other_node] 309 | v_indices = res_id[rand_id] 310 | e_adj = res_adj[:, rand_id] 311 | elif self.base == "v": 312 | # get the base node for u 313 | v_indices = batch 314 | res_adj, res_id = self.sample_u_given_v( 315 | v_indices, empty_list, num_neighbors=self.num_neighbors_v[0] 316 | ) 317 | rand_id = torch.randperm(len(res_id))[: self.n_other_node] 318 | u_indices = res_id[rand_id] 319 | e_adj = res_adj[rand_id, :] 320 | elif self.base == "e": 321 | # get the base node for u and v 322 | row = self.adj.storage.row()[batch] 323 | col = self.adj.storage.col()[batch] 324 | unique_row, invidx_row = torch.unique(row, return_inverse=True) 325 | unique_col, invidx_col = torch.unique(col, return_inverse=True) 326 | 327 | reindex_row_id = torch.arange(len(unique_row)) 328 | reindex_col_id = torch.arange(len(unique_col)) 329 | reindex_row = reindex_row_id[invidx_row] 330 | reindex_col = reindex_col_id[invidx_col] 331 | 332 | e_adj = SparseTensor(row=reindex_row, col=reindex_col, value=batch) 333 | e_indices = batch 334 | u_indices = unique_row 335 | v_indices = unique_col 336 | 337 | # init results 338 | adjacencies = [] 339 | e_flags = [] 340 | 341 | ## for subsequent layers 342 | for i in range(self.n_layers): 343 | 344 | # v -> u 345 | u_adj, next_v_indices = self.sample_v_given_u( 346 | u_indices, prev_v=v_indices, num_neighbors=self.num_neighbors_u[i] 347 | ) 348 | dir_adj_v2u = self.adjacency_from_samples( 349 | u_adj, u_indices, next_v_indices, "v->u" 350 | ) 351 | 352 | # u -> v 353 | v_adj, next_u_indices = self.sample_u_given_v( 354 | v_indices, prev_u=u_indices, num_neighbors=self.num_neighbors_v[i] 355 | ) 356 | dir_adj_u2v = 
self.adjacency_from_samples( 357 | v_adj, next_u_indices, v_indices, "u->v" 358 | ) 359 | 360 | # u -> e <- v 361 | dir_adj_e = self.adjacency_from_samples( 362 | e_adj, u_indices, v_indices, "u->e<-v" 363 | ) 364 | 365 | # add them to the list 366 | adjacencies.append(BEANAdjacency(dir_adj_v2u, dir_adj_u2v, dir_adj_e)) 367 | 368 | # for next iter 369 | e_adj, e_flag = self.combine_adjacency( 370 | v2u_adj=u_adj, u2v_adj=v_adj, e_adj=e_adj 371 | ) 372 | u_indices = next_u_indices 373 | v_indices = next_v_indices 374 | e_flags.append(e_flag) 375 | 376 | # flip the order 377 | adjacencies = adjacencies[0] if len(adjacencies) == 1 else adjacencies[::-1] 378 | e_flags = e_flags[0] if len(e_flags) == 1 else e_flags[::-1] 379 | 380 | # get e_indices 381 | e_indices = e_adj.storage.value() 382 | 383 | # print(f"sampling : {time.time() - start} s") 384 | 385 | return batch_size, (u_indices, v_indices, e_indices), adjacencies, e_flags 386 | 387 | 388 | class EdgeLoader(torch.utils.data.DataLoader): 389 | def __init__(self, adj: SparseTensor, **kwargs): 390 | 391 | edge_idx = torch.arange(adj.nnz()) 392 | self.adj = adj 393 | 394 | super().__init__(edge_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs) 395 | 396 | def sample(self, batch): 397 | 398 | if not isinstance(batch, Tensor): 399 | batch = torch.tensor(batch) 400 | 401 | row = self.adj.storage.row()[batch] 402 | col = self.adj.storage.col()[batch] 403 | if self.adj.storage.has_value(): 404 | val = self.adj.storage.col()[batch] 405 | else: 406 | val = batch 407 | 408 | # get unique row, col & idx 409 | unique_row, invidx_row = torch.unique(row, return_inverse=True) 410 | unique_col, invidx_col = torch.unique(col, return_inverse=True) 411 | 412 | reindex_row_id = torch.arange(len(unique_row)) 413 | reindex_col_id = torch.arange(len(unique_col)) 414 | 415 | reindex_row = reindex_row_id[invidx_row] 416 | reindex_col = reindex_col_id[invidx_col] 417 | 418 | adj = SparseTensor(row=reindex_row, col=reindex_col, value=val) 419 | e_id = batch 420 | u_id = unique_row 421 | v_id = unique_col 422 | 423 | adj_e = DirectedAdj(adj, u_id, v_id, e_id, adj.sparse_sizes(), "u->e<-v") 424 | 425 | return adj_e 426 | -------------------------------------------------------------------------------- /models/score.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | from typing import Dict 5 | import torch 6 | from torch import Tensor 7 | 8 | from torch_scatter import scatter 9 | from torch_sparse import SparseTensor 10 | 11 | from sklearn.metrics import ( 12 | accuracy_score, 13 | f1_score, 14 | precision_score, 15 | recall_score, 16 | roc_curve, 17 | precision_recall_curve, 18 | auc, 19 | ) 20 | import pandas as pd 21 | 22 | 23 | def compute_anomaly_score( 24 | xu: Tensor, 25 | xv: Tensor, 26 | xe: Tensor, 27 | adj: SparseTensor, 28 | edge_pred_samples: SparseTensor, 29 | out: Dict[str, Tensor], 30 | xe_loss_weight: float = 1.0, 31 | structure_loss_weight: float = 1.0, 32 | ) -> Dict[str, Tensor]: 33 | 34 | # node error, use RMSE instead of MSE 35 | xu_error = torch.sqrt(torch.mean((xu - out["xu"]) ** 2, dim=1)) 36 | xv_error = torch.sqrt(torch.mean((xv - out["xv"]) ** 2, dim=1)) 37 | 38 | # edge error, use RMSE instead of MSE 39 | xe_error = torch.sqrt(torch.mean((xe - out["xe"]) ** 2, dim=1)) 40 | 41 | # edge prediction cross entropy 42 | edge_ce = -torch.log(out["eprob"][edge_pred_samples.storage.value() > 0] + 1e-12) 43 | 44 | # edge score 45 | e_score = xe_loss_weight * xe_error + structure_loss_weight * edge_ce 46 | 47 | # edge score 48 | u_score_edge_max = xu_error + scatter( 49 | e_score, adj.storage.row(), dim=0, reduce="max" 50 | ) 51 | v_score_edge_max = xv_error + scatter( 52 | e_score, adj.storage.col(), dim=0, reduce="max" 53 | ) 54 | u_score_edge_mean = xu_error + scatter( 55 | e_score, adj.storage.row(), dim=0, reduce="mean" 56 | ) 57 | v_score_edge_mean = xv_error + scatter( 58 | e_score, adj.storage.col(), dim=0, reduce="mean" 59 | ) 60 | u_score_edge_sum = xu_error + scatter( 61 | e_score, adj.storage.row(), dim=0, reduce="sum" 62 | ) 63 | v_score_edge_sum = xv_error + scatter( 64 | e_score, adj.storage.col(), dim=0, reduce="sum" 65 | ) 66 | 67 | anomaly_score = { 68 | "xu_error": xu_error, 69 | "xv_error": xv_error, 70 | "xe_error": xe_error, 71 | "edge_ce": edge_ce, 72 | "e_score": e_score, 73 | "u_score_edge_max": u_score_edge_max, 74 | "u_score_edge_mean": u_score_edge_mean, 75 | "u_score_edge_sum": u_score_edge_sum, 76 | "v_score_edge_max": v_score_edge_max, 77 | "v_score_edge_mean": v_score_edge_mean, 78 | "v_score_edge_sum": v_score_edge_sum, 79 | } 80 | 81 | return anomaly_score 82 | 83 | 84 | def edge_prediction_metric( 85 | edge_pred_samples: SparseTensor, edge_prob: Tensor 86 | ) -> Dict[str, float]: 87 | 88 | edge_pred = (edge_prob >= 0.5).int().cpu().numpy() 89 | edge_gt = (edge_pred_samples.storage.value() > 0).int().cpu().numpy() 90 | 91 | acc = accuracy_score(edge_gt, edge_pred) 92 | prec = precision_score(edge_gt, edge_pred) 93 | rec = recall_score(edge_gt, edge_pred) 94 | f1 = f1_score(edge_gt, edge_pred) 95 | 96 | result = {"acc": acc, "prec": prec, "rec": rec, "f1": f1} 97 | return result 98 | 99 | 100 | def compute_evaluation_metrics( 101 | anomaly_score: Dict[str, Tensor], yu: Tensor, yv: Tensor, ye: Tensor, agg="max" 102 | ): 103 | 104 | # node u 105 | u_roc_curve = roc_curve( 106 | yu.cpu().numpy(), anomaly_score[f"u_score_edge_{agg}"].cpu().numpy() 107 | ) 108 | u_pr_curve = precision_recall_curve( 109 | yu.cpu().numpy(), anomaly_score[f"u_score_edge_{agg}"].cpu().numpy() 110 | ) 111 | u_roc_auc = auc(u_roc_curve[0], u_roc_curve[1]) 112 | u_pr_auc = auc(u_pr_curve[1], u_pr_curve[0]) 113 | 114 | # node v 115 | v_roc_curve = roc_curve( 116 | yv.cpu().numpy(), 
anomaly_score[f"v_score_edge_{agg}"].cpu().numpy() 117 | ) 118 | v_pr_curve = precision_recall_curve( 119 | yv.cpu().numpy(), anomaly_score[f"v_score_edge_{agg}"].cpu().numpy() 120 | ) 121 | v_roc_auc = auc(v_roc_curve[0], v_roc_curve[1]) 122 | v_pr_auc = auc(v_pr_curve[1], v_pr_curve[0]) 123 | 124 | # nedge 125 | e_roc_curve = roc_curve(ye.cpu().numpy(), anomaly_score["xe_error"].cpu().numpy()) 126 | e_pr_curve = precision_recall_curve( 127 | ye.cpu().numpy(), anomaly_score["xe_error"].cpu().numpy() 128 | ) 129 | e_roc_auc = auc(e_roc_curve[0], e_roc_curve[1]) 130 | e_pr_auc = auc(e_pr_curve[1], e_pr_curve[0]) 131 | 132 | metrics = { 133 | "u_roc_curve": u_roc_curve, 134 | "u_pr_curve": u_pr_curve, 135 | "u_roc_auc": u_roc_auc, 136 | "u_pr_auc": u_pr_auc, 137 | "v_roc_curve": v_roc_curve, 138 | "v_pr_curve": v_pr_curve, 139 | "v_roc_auc": v_roc_auc, 140 | "v_pr_auc": v_pr_auc, 141 | "e_roc_curve": e_roc_curve, 142 | "e_pr_curve": e_pr_curve, 143 | "e_roc_auc": e_roc_auc, 144 | "e_pr_auc": e_pr_auc, 145 | } 146 | 147 | return metrics 148 | 149 | 150 | def attach_anomaly_score( 151 | anomaly_score: Dict[str, Tensor], 152 | dfu_id: pd.DataFrame, 153 | dfv_id: pd.DataFrame, 154 | dfe_id: pd.DataFrame, 155 | ): 156 | 157 | dfu_id = dfu_id.assign( 158 | xu_error=anomaly_score["xu_error"].cpu().numpy(), 159 | u_score_edge_max=anomaly_score["u_score_edge_max"].cpu().numpy(), 160 | u_score_edge_mean=anomaly_score["u_score_edge_mean"].cpu().numpy(), 161 | ) 162 | 163 | dfv_id = dfv_id.assign( 164 | xv_error=anomaly_score["xv_error"].cpu().numpy(), 165 | v_score_edge_max=anomaly_score["v_score_edge_max"].cpu().numpy(), 166 | v_score_edge_mean=anomaly_score["v_score_edge_mean"].cpu().numpy(), 167 | ) 168 | 169 | dfe_id = dfe_id.assign( 170 | xe_error=anomaly_score["xe_error"].cpu().numpy(), 171 | edge_ce=anomaly_score["edge_ce"].cpu().numpy(), 172 | e_score=anomaly_score["e_score"].cpu().numpy(), 173 | ) 174 | 175 | return dfu_id, dfv_id, dfe_id 176 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.6 2 | pandas==1.1.5 3 | pygod==0.3.0 4 | scikit_learn==1.0.2 5 | scipy==1.7.2 6 | sentence_transformers==2.2.2 7 | torch==1.9.0 8 | torch_geometric==2.0.3 9 | torch_scatter==2.0.8 10 | torch_sparse==0.6.12 11 | tqdm==4.62.3 12 | tensorboard==2.9.1 -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /storage/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /train_full_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import sys 5 | 6 | from data_finefoods import load_graph 7 | from models.score import compute_evaluation_metrics 8 | 9 | import time 10 | from tqdm import tqdm 11 | import argparse 12 | import os 13 | 14 | from torch.utils.tensorboard import SummaryWriter 15 | import datetime 16 | 17 | import torch 18 | 19 | from models.data import BipartiteData 20 | from models.net import GraphBEAN 21 | from models.sampler import EdgePredictionSampler 22 | from models.loss import reconstruction_loss 23 | from models.score import compute_anomaly_score, edge_prediction_metric 24 | 25 | from utils.seed import seed_all 26 | 27 | # %% args 28 | 29 | parser = argparse.ArgumentParser(description="GraphBEAN") 30 | parser.add_argument("--name", type=str, default="wikipedia_anomaly", help="name") 31 | parser.add_argument( 32 | "--key", type=str, default="graph_anomaly_list", help="key to the data" 33 | ) 34 | parser.add_argument("--id", type=int, default=0, help="id to the data") 35 | parser.add_argument("--n-epoch", type=int, default=200, help="number of epoch") 36 | parser.add_argument( 37 | "--scheduler-milestones", 38 | nargs="+", 39 | type=int, 40 | default=[], 41 | help="scheduler milestone", 42 | ) 43 | parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") 44 | parser.add_argument( 45 | "--score-agg", type=str, default="max", help="aggregation for node anomaly score" 46 | ) 47 | parser.add_argument("--eta", type=float, default=0.2, help="structure loss weight") 48 | 49 | args1 = vars(parser.parse_args()) 50 | 51 | args2 = { 52 | "hidden_channels": 32, 53 | "latent_channels_u": 32, 54 | "latent_channels_v": 32, 55 | "edge_pred_latent": 32, 56 | "n_layers_encoder": 2, 57 | "n_layers_decoder": 2, 58 | "n_layers_mlp": 2, 59 | "dropout_prob": 0.0, 60 | "gamma": 0.2, 61 | "xe_loss_weight": 1.0, 62 | "structure_loss_weight": args1["eta"], 63 | "structure_loss_weight_anomaly_score": args1["eta"], 64 | "iter_check": 10, 65 | "seed": 0, 66 | "neg_sampler_mult": 5, 67 | "k_check": 15, 68 | "tensorboard": False, 69 | "progress_bar": True, 70 | } 71 | 72 | args = {**args1, **args2} 73 | 74 | seed_all(args["seed"]) 75 | 76 | result_dir = "results/" 77 | 78 | 79 | # %% data 80 | data = load_graph(args["name"], args["key"], args["id"]) 81 | 82 | u_ch = data.xu.shape[1] 83 | v_ch = data.xv.shape[1] 84 | e_ch = data.xe.shape[1] 85 | 86 | print( 87 | f"Data dimension: U node = {data.xu.shape}; V node = {data.xv.shape}; E edge = {data.xe.shape}; \n" 88 | ) 89 | 90 | # %% model 91 | 92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 93 | model = GraphBEAN( 94 | in_channels=(u_ch, v_ch, e_ch), 95 | hidden_channels=args["hidden_channels"], 96 | latent_channels=(args["latent_channels_u"], args["latent_channels_v"]), 97 | edge_pred_latent=args["edge_pred_latent"], 98 | n_layers_encoder=args["n_layers_encoder"], 99 | n_layers_decoder=args["n_layers_decoder"], 100 | n_layers_mlp=args["n_layers_mlp"], 101 | dropout_prob=args["dropout_prob"], 102 | ) 103 | 104 | model = model.to(device) 105 | optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"]) 106 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 107 | optimizer, milestones=args["scheduler_milestones"], gamma=args["gamma"] 108 | ) 109 | 110 | xu, xv = data.xu.to(device), data.xv.to(device) 111 | xe, adj = data.xe.to(device), data.adj.to(device) 112 | yu, yv, ye = data.yu.to(device), data.yv.to(device), 
data.ye.to(device) 113 | 114 | 115 | # sampler 116 | sampler = EdgePredictionSampler(adj, mult=args["neg_sampler_mult"]) 117 | 118 | print(args) 119 | print() 120 | 121 | # %% train 122 | def train(epoch): 123 | 124 | model.train() 125 | 126 | edge_pred_samples = sampler.sample() 127 | 128 | optimizer.zero_grad() 129 | out = model(xu, xv, xe, adj, edge_pred_samples) 130 | 131 | loss, loss_component = reconstruction_loss( 132 | xu, 133 | xv, 134 | xe, 135 | adj, 136 | edge_pred_samples, 137 | out, 138 | xe_loss_weight=args["xe_loss_weight"], 139 | structure_loss_weight=args["structure_loss_weight"], 140 | ) 141 | 142 | loss.backward() 143 | optimizer.step() 144 | scheduler.step() 145 | 146 | epred_metric = edge_prediction_metric(edge_pred_samples, out["eprob"]) 147 | 148 | return loss, loss_component, epred_metric 149 | 150 | 151 | # %% evaluate and store 152 | def eval(epoch): 153 | 154 | # model.eval() 155 | 156 | start = time.time() 157 | 158 | # negative sampling 159 | edge_pred_samples = sampler.sample() 160 | 161 | with torch.no_grad(): 162 | 163 | out = model(xu, xv, xe, adj, edge_pred_samples) 164 | 165 | loss, loss_component = reconstruction_loss( 166 | xu, 167 | xv, 168 | xe, 169 | adj, 170 | edge_pred_samples, 171 | out, 172 | xe_loss_weight=args["xe_loss_weight"], 173 | structure_loss_weight=args["structure_loss_weight"], 174 | ) 175 | 176 | epred_metric = edge_prediction_metric(edge_pred_samples, out["eprob"]) 177 | 178 | anomaly_score = compute_anomaly_score( 179 | xu, 180 | xv, 181 | xe, 182 | adj, 183 | edge_pred_samples, 184 | out, 185 | xe_loss_weight=args["xe_loss_weight"], 186 | structure_loss_weight=args["structure_loss_weight_anomaly_score"], 187 | ) 188 | 189 | eval_metrics = compute_evaluation_metrics( 190 | anomaly_score, yu, yv, ye, agg=args["score_agg"] 191 | ) 192 | 193 | elapsed = time.time() - start 194 | 195 | print( 196 | f"Eval, loss: {loss:.4f}, " 197 | + f"u auc-roc: {eval_metrics['u_roc_auc']:.4f}, v auc-roc: {eval_metrics['v_roc_auc']:.4f}, e auc-roc: {eval_metrics['e_roc_auc']:.4f}, " 198 | + f"u auc-pr {eval_metrics['u_pr_auc']:.4f}, v auc-pr {eval_metrics['v_pr_auc']:.4f}, e auc-pr {eval_metrics['e_pr_auc']:.4f} " 199 | + f"> {elapsed:.2f}s" 200 | ) 201 | 202 | if args["tensorboard"]: 203 | tb.add_scalar("loss", loss, epoch) 204 | tb.add_scalar("u_roc_auc", eval_metrics["u_roc_auc"], epoch) 205 | tb.add_scalar("u_pr_auc", eval_metrics["u_pr_auc"], epoch) 206 | tb.add_scalar("v_roc_auc", eval_metrics["v_roc_auc"], epoch) 207 | tb.add_scalar("v_pr_auc", eval_metrics["v_pr_auc"], epoch) 208 | tb.add_scalar("e_roc_auc", eval_metrics["e_roc_auc"], epoch) 209 | tb.add_scalar("e_pr_auc", eval_metrics["e_pr_auc"], epoch) 210 | 211 | model_stored = { 212 | "args": args, 213 | "loss": loss, 214 | "loss_component": loss_component, 215 | "epred_metric": epred_metric, 216 | "eval_metrics": eval_metrics, 217 | "loss_hist": loss_hist, 218 | "loss_component_hist": loss_component_hist, 219 | "epred_metric_hist": epred_metric_hist, 220 | "state_dict": model.state_dict(), 221 | "optimizer_state_dict": optimizer.state_dict(), 222 | } 223 | output_stored = {"args": args, "out": out, "anomaly_score": anomaly_score} 224 | 225 | print("Saving current results...") 226 | torch.save( 227 | model_stored, 228 | os.path.join( 229 | result_dir, 230 | f"graphbean-{args['name']}-{args['id']}-eta-{args['eta']}-structure-model.th", 231 | ), 232 | ) 233 | torch.save( 234 | output_stored, 235 | os.path.join( 236 | result_dir, 237 | 
f"graphbean-{args['name']}-{args['id']}-eta-{args['eta']}-structure-output.th", 238 | ), 239 | ) 240 | 241 | return loss, loss_component, epred_metric 242 | 243 | 244 | # %% run training 245 | loss_hist = [] 246 | loss_component_hist = [] 247 | epred_metric_hist = [] 248 | 249 | # tensor board 250 | if args["tensorboard"]: 251 | log_dir = ( 252 | "/logs/tensorboard/" 253 | + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 254 | + "-" 255 | + args["name"] 256 | ) 257 | tb = SummaryWriter(log_dir=log_dir, comment=args["name"]) 258 | check_counter = 0 259 | 260 | eval(0) 261 | 262 | for epoch in range(args["n_epoch"]): 263 | 264 | start = time.time() 265 | loss, loss_component, epred_metric = train(epoch) 266 | elapsed = time.time() - start 267 | 268 | loss_hist.append(loss) 269 | loss_component_hist.append(loss_component) 270 | epred_metric_hist.append(epred_metric) 271 | 272 | print( 273 | f"#{epoch:3d}, " 274 | + f"Loss: {loss:.4f} => xu: {loss_component['xu']:.4f}, xv: {loss_component['xv']:.4f}, " 275 | + f"xe: {loss_component['xe']:.4f}, " 276 | + f"e: {loss_component['e']:.4f} -> " 277 | + f"[acc: {epred_metric['acc']:.3f}, f1: {epred_metric['f1']:.3f} -> " 278 | + f"prec: {epred_metric['prec']:.3f}, rec: {epred_metric['rec']:.3f}] " 279 | + f"> {elapsed:.2f}s" 280 | ) 281 | 282 | if epoch % args["iter_check"] == 0: # and epoch != 0: 283 | # tb eval 284 | eval(epoch) 285 | 286 | 287 | # %% after training 288 | res = eval(args["n_epoch"]) 289 | ev_loss, ev_loss_component, ev_epred_metric = res 290 | 291 | if args["tensorboard"]: 292 | tb.add_hparams( 293 | args, 294 | { 295 | "loss": ev_loss, 296 | "xu": ev_loss_component["xu"], 297 | "xv": ev_loss_component["xv"], 298 | "xe": ev_loss_component["xe"], 299 | "e": ev_loss_component["e"], 300 | "acc": ev_epred_metric["acc"], 301 | "f1": ev_epred_metric["f1"], 302 | "prec": ev_epred_metric["prec"], 303 | "rec": ev_epred_metric["rec"], 304 | }, 305 | ) 306 | 307 | print() 308 | print(args) 309 | -------------------------------------------------------------------------------- /train_sample_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import sys 5 | 6 | from data_finefoods import load_graph 7 | from models.score import compute_evaluation_metrics 8 | 9 | import time 10 | from tqdm import tqdm 11 | import argparse 12 | import os 13 | 14 | from torch.utils.tensorboard import SummaryWriter 15 | import datetime 16 | 17 | import torch 18 | 19 | from models.data import BipartiteData 20 | from models.net_sample import GraphBEANSampled 21 | from models.sampler import BipartiteNeighborSampler 22 | from models.sampler import EdgePredictionSampler 23 | from models.loss import reconstruction_loss 24 | from models.score import compute_anomaly_score, edge_prediction_metric 25 | 26 | from utils.sum_dict import dict_addto, dict_div 27 | from utils.seed import seed_all 28 | 29 | # %% args 30 | 31 | parser = argparse.ArgumentParser(description="GraphBEAN") 32 | parser.add_argument("--name", type=str, default="finefoods_anomaly", help="name") 33 | parser.add_argument( 34 | "--key", type=str, default="graph_anomaly_list", help="key to the data" 35 | ) 36 | parser.add_argument("--id", type=int, default=0, help="id to the data") 37 | parser.add_argument("--batch-size", type=int, default=2048, help="batch size") 38 | parser.add_argument( 39 | "--num-neighbors-u", 40 | type=int, 41 | default=10, 42 | help="number of neighbors for node u in sampling", 43 | ) 44 | parser.add_argument( 45 | "--num-neighbors-v", 46 | type=int, 47 | default=10, 48 | help="number of neighbors for node v in sampling", 49 | ) 50 | parser.add_argument("--n-epoch", type=int, default=50, help="number of epoch") 51 | parser.add_argument( 52 | "--scheduler-milestones", 53 | nargs="+", 54 | type=int, 55 | default=[20, 35], 56 | help="scheduler milestone", 57 | ) 58 | parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") 59 | parser.add_argument( 60 | "--score-agg", type=str, default="max", help="aggregation for node anomaly score" 61 | ) 62 | parser.add_argument( 63 | "--num-workers", 64 | type=int, 65 | default=0, 66 | help="number of workers in neighborhood sampling loader", 67 | ) 68 | 69 | args1 = vars(parser.parse_args()) 70 | 71 | args2 = { 72 | "hidden_channels": 32, 73 | "latent_channels_u": 32, 74 | "latent_channels_v": 32, 75 | "edge_pred_latent": 32, 76 | "n_layers_encoder": 2, 77 | "n_layers_decoder": 2, 78 | "n_layers_mlp": 2, 79 | "dropout_prob": 0.0, 80 | "gamma": 0.2, 81 | "xe_loss_weight": 1.0, 82 | "structure_loss_weight": 0.2, 83 | "structure_loss_weight_anomaly_score": 0.2, 84 | "iter_check": 10, 85 | "seed": 0, 86 | "neg_sampler_mult": 3, 87 | "k_check": 15, 88 | "tensorboard": False, 89 | "progress_bar": False, 90 | } 91 | 92 | args = {**args1, **args2} 93 | 94 | seed_all(args["seed"]) 95 | 96 | result_dir = "results/" 97 | 98 | 99 | # %% params 100 | batch_size = args["batch_size"] 101 | 102 | # %% data 103 | data = load_graph(args["name"], args["key"], args["id"]) 104 | print(data) 105 | 106 | u_ch = data.xu.shape[1] 107 | v_ch = data.xv.shape[1] 108 | e_ch = data.xe.shape[1] 109 | 110 | print( 111 | f"Data dimension: U node = {data.xu.shape}; V node = {data.xv.shape}; E edge = {data.xe.shape}; \n" 112 | ) 113 | 114 | # %% model 115 | 116 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 117 | model = GraphBEANSampled( 118 | in_channels=(u_ch, v_ch, e_ch), 119 | hidden_channels=args["hidden_channels"], 120 | latent_channels=(args["latent_channels_u"], args["latent_channels_v"]), 121 | 
edge_pred_latent=args["edge_pred_latent"], 122 | n_layers_encoder=args["n_layers_encoder"], 123 | n_layers_decoder=args["n_layers_decoder"], 124 | n_layers_mlp=args["n_layers_mlp"], 125 | dropout_prob=args["dropout_prob"], 126 | ) 127 | 128 | model = model.to(device) 129 | optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"]) 130 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 131 | optimizer, milestones=args["scheduler_milestones"], gamma=args["gamma"] 132 | ) 133 | 134 | xu, xv = data.xu, data.xv 135 | xe, adj = data.xe, data.adj 136 | yu, yv, ye = data.yu, data.yv, data.ye 137 | 138 | # sampler 139 | train_loader = BipartiteNeighborSampler( 140 | adj, 141 | n_layers=4, 142 | base="v", 143 | batch_size=batch_size, 144 | drop_last=True, 145 | n_other_node=-1, 146 | num_neighbors_u=args["num_neighbors_u"], 147 | num_neighbors_v=args["num_neighbors_v"], 148 | num_workers=args["num_workers"], 149 | shuffle=True, 150 | ) 151 | 152 | print(args) 153 | print() 154 | 155 | # %% train 156 | def train(epoch, check_counter): 157 | 158 | model.train() 159 | 160 | n_batch = len(train_loader) 161 | if args["progress_bar"]: 162 | pbar = tqdm(total=n_batch, leave=False) 163 | pbar.set_description(f"#{epoch:3d}") 164 | 165 | total_loss = 0 166 | total_epred_metric = {"acc": 0.0, "prec": 0.0, "rec": 0.0, "f1": 0.0} 167 | total_loss_component = {"xu": 0.0, "xv": 0.0, "xe": 0.0, "e": 0.0, "total": 0.0} 168 | num_update = 0 169 | 170 | for batch_size, indices, adjacencies, e_flags in train_loader: 171 | 172 | # print(f"# u nodes: {len(indices[0])} | # v nodes: {len(indices[1])} | # edges: {len(indices[2])}") 173 | 174 | adjacencies = [adj.to(device) for adj in adjacencies] 175 | e_flags = [fl.to(device) for fl in e_flags] 176 | u_id, v_id, e_id = indices 177 | 178 | # sample 179 | xu_sample = xu[u_id].to(device) 180 | xv_sample = xv[v_id].to(device) 181 | xe_sample = xe[e_id].to(device) 182 | 183 | # edge pred samples 184 | target_adj = adjacencies[-1].adj_e.adj 185 | edge_pred_sampler = EdgePredictionSampler( 186 | target_adj, mult=args["neg_sampler_mult"] 187 | ) 188 | edge_pred_samples = edge_pred_sampler.sample().to(device) 189 | 190 | optimizer.zero_grad() 191 | 192 | # start = time.time() 193 | out = model( 194 | xu=xu_sample, 195 | xv=xv_sample, 196 | xe=xe_sample, 197 | bean_adjs=adjacencies, 198 | e_flags=e_flags, 199 | edge_pred_samples=edge_pred_samples, 200 | ) 201 | # print(f"training : {time.time() - start} s") 202 | 203 | last_adj_e = adjacencies[-1].adj_e 204 | xu_target = xu[last_adj_e.u_id].to(device) 205 | xv_target = xv[last_adj_e.v_id].to(device) 206 | xe_target = xe[last_adj_e.e_id].to(device) 207 | 208 | loss, loss_component = reconstruction_loss( 209 | xu=xu_target, 210 | xv=xv_target, 211 | xe=xe_target, 212 | adj=last_adj_e.adj, 213 | edge_pred_samples=edge_pred_samples, 214 | out=out, 215 | xe_loss_weight=args["xe_loss_weight"], 216 | structure_loss_weight=args["structure_loss_weight"], 217 | ) 218 | 219 | loss.backward() 220 | optimizer.step() 221 | 222 | epred_metric = edge_prediction_metric(edge_pred_samples, out["eprob"]) 223 | 224 | total_loss += float(loss) 225 | total_epred_metric = dict_addto(total_epred_metric, epred_metric) 226 | total_loss_component = dict_addto(total_loss_component, loss_component) 227 | num_update += 1 228 | 229 | if args["progress_bar"]: 230 | pbar.update(1) 231 | pbar.set_postfix( 232 | { 233 | "loss": float(loss), 234 | "ep acc": epred_metric["acc"], 235 | "ep f1": epred_metric["f1"], 236 | } 237 | ) 238 | 239 | if num_update == 
args["k_check"]: 240 | loss = total_loss / num_update 241 | loss_component = dict_div(total_loss_component, num_update) 242 | epred_metric = dict_div(total_epred_metric, num_update) 243 | 244 | # tensorboard 245 | if args["tensorboard"]: 246 | tb.add_scalar("loss", loss, check_counter) 247 | tb.add_scalar("loss_xu", loss_component["xu"], check_counter) 248 | tb.add_scalar("loss_xv", loss_component["xv"], check_counter) 249 | tb.add_scalar("loss_xe", loss_component["xe"], check_counter) 250 | tb.add_scalar("loss_e", loss_component["e"], check_counter) 251 | 252 | tb.add_scalar("epred_acc", epred_metric["acc"], check_counter) 253 | tb.add_scalar("epred_f1", epred_metric["f1"], check_counter) 254 | tb.add_scalar("epred_prec", epred_metric["prec"], check_counter) 255 | tb.add_scalar("epred_rec", epred_metric["rec"], check_counter) 256 | 257 | check_counter += 1 258 | 259 | total_loss = 0 260 | total_epred_metric = {"acc": 0.0, "prec": 0.0, "rec": 0.0, "f1": 0.0} 261 | total_loss_component = { 262 | "xu": 0.0, 263 | "xv": 0.0, 264 | "xe": 0.0, 265 | "e": 0.0, 266 | "total": 0.0, 267 | } 268 | num_update = 0 269 | 270 | if args["progress_bar"]: 271 | pbar.close() 272 | scheduler.step() 273 | 274 | return loss, loss_component, epred_metric, check_counter 275 | 276 | 277 | # %% evaluate and store 278 | def eval(epoch): 279 | 280 | model.eval() 281 | 282 | start = time.time() 283 | 284 | # negative sampling 285 | edge_pred_sampler = EdgePredictionSampler(adj, mult=args["neg_sampler_mult"]) 286 | edge_pred_samples = edge_pred_sampler.sample() 287 | 288 | with torch.no_grad(): 289 | 290 | out = model.inference( 291 | xu, 292 | xv, 293 | xe, 294 | adj, 295 | edge_pred_samples, 296 | batch_sizes=(2**13, 2**13, 2**13), 297 | device=device, 298 | progress_bar=args["progress_bar"], 299 | ) 300 | 301 | loss, loss_component = reconstruction_loss( 302 | xu, 303 | xv, 304 | xe, 305 | adj, 306 | edge_pred_samples, 307 | out, 308 | xe_loss_weight=args["xe_loss_weight"], 309 | structure_loss_weight=args["structure_loss_weight"], 310 | ) 311 | 312 | epred_metric = edge_prediction_metric(edge_pred_samples, out["eprob"]) 313 | 314 | anomaly_score = compute_anomaly_score( 315 | xu, 316 | xv, 317 | xe, 318 | adj, 319 | edge_pred_samples, 320 | out, 321 | xe_loss_weight=args["xe_loss_weight"], 322 | structure_loss_weight=args["structure_loss_weight_anomaly_score"], 323 | ) 324 | 325 | eval_metrics = compute_evaluation_metrics( 326 | anomaly_score, yu, yv, ye, agg=args["score_agg"] 327 | ) 328 | 329 | elapsed = time.time() - start 330 | 331 | print( 332 | f"Eval, loss: {loss:.4f}, " 333 | + f"u auc-roc: {eval_metrics['u_roc_auc']:.4f}, v auc-roc: {eval_metrics['v_roc_auc']:.4f}, e auc-roc: {eval_metrics['e_roc_auc']:.4f}, " 334 | + f"u auc-pr {eval_metrics['u_pr_auc']:.4f}, v auc-pr {eval_metrics['v_pr_auc']:.4f}, e auc-pr {eval_metrics['e_pr_auc']:.4f} " 335 | + f"> {elapsed:.2f}s" 336 | ) 337 | 338 | if args["tensorboard"]: 339 | tb.add_scalar("loss", loss, epoch) 340 | tb.add_scalar("u_roc_auc", eval_metrics["u_roc_auc"], epoch) 341 | tb.add_scalar("u_pr_auc", eval_metrics["u_pr_auc"], epoch) 342 | tb.add_scalar("v_roc_auc", eval_metrics["v_roc_auc"], epoch) 343 | tb.add_scalar("v_pr_auc", eval_metrics["v_pr_auc"], epoch) 344 | tb.add_scalar("e_roc_auc", eval_metrics["e_roc_auc"], epoch) 345 | tb.add_scalar("e_pr_auc", eval_metrics["e_pr_auc"], epoch) 346 | 347 | model_stored = { 348 | "args": args, 349 | "loss": loss, 350 | "loss_component": loss_component, 351 | "epred_metric": epred_metric, 352 | "eval_metrics": 
eval_metrics, 353 | "loss_hist": loss_hist, 354 | "loss_component_hist": loss_component_hist, 355 | "epred_metric_hist": epred_metric_hist, 356 | "state_dict": model.state_dict(), 357 | "optimizer_state_dict": optimizer.state_dict(), 358 | } 359 | output_stored = {"args": args, "out": out, "anomaly_score": anomaly_score} 360 | 361 | print("Saving current results...") 362 | torch.save( 363 | model_stored, 364 | os.path.join(result_dir, f"graphbean-{args['name']}-{args['id']}-model.th"), 365 | ) 366 | torch.save( 367 | output_stored, 368 | os.path.join(result_dir, f"graphbean-{args['name']}-{args['id']}-output.th"), 369 | ) 370 | 371 | return loss, loss_component, epred_metric 372 | 373 | 374 | # %% run training 375 | loss_hist = [] 376 | loss_component_hist = [] 377 | epred_metric_hist = [] 378 | 379 | # tensor board 380 | if args["tensorboard"]: 381 | log_dir = ( 382 | "/logs/tensorboard/" 383 | + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 384 | + "-" 385 | + args["name"] 386 | ) 387 | tb = SummaryWriter(log_dir=log_dir, comment=args["name"]) 388 | check_counter = 0 389 | 390 | # eval(0) 391 | 392 | for epoch in range(args["n_epoch"]): 393 | 394 | start = time.time() 395 | loss, loss_component, epred_metric, check_counter = train(epoch, check_counter) 396 | elapsed = time.time() - start 397 | 398 | loss_hist.append(loss) 399 | loss_component_hist.append(loss_component) 400 | epred_metric_hist.append(epred_metric) 401 | 402 | print( 403 | f"#{epoch:3d}, " 404 | + f"Loss: {loss:.4f} => xu: {loss_component['xu']:.4f}, xv: {loss_component['xv']:.4f}, " 405 | + f"xe: {loss_component['xe']:.4f}, " 406 | + f"e: {loss_component['e']:.4f} -> " 407 | + f"[acc: {epred_metric['acc']:.3f}, f1: {epred_metric['f1']:.3f} -> " 408 | + f"prec: {epred_metric['prec']:.3f}, rec: {epred_metric['rec']:.3f}] " 409 | + f"> {elapsed:.2f}s" 410 | ) 411 | 412 | if epoch % args["iter_check"] == 0: # and epoch != 0: 413 | # tb eval 414 | eval(epoch) 415 | 416 | 417 | # %% after training 418 | res = eval(args["n_epoch"]) 419 | ev_loss, ev_loss_component, ev_epred_metric = res 420 | 421 | if args["tensorboard"]: 422 | tb.add_hparams( 423 | args, 424 | { 425 | "loss": ev_loss, 426 | "xu": ev_loss_component["xu"], 427 | "xv": ev_loss_component["xv"], 428 | "xe": ev_loss_component["xe"], 429 | "e": ev_loss_component["e"], 430 | "acc": ev_epred_metric["acc"], 431 | "f1": ev_epred_metric["f1"], 432 | "prec": ev_epred_metric["prec"], 433 | "rec": ev_epred_metric["rec"], 434 | }, 435 | ) 436 | 437 | print() 438 | print(args) 439 | -------------------------------------------------------------------------------- /utils/seed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import numpy as np 5 | import random 6 | import torch 7 | 8 | 9 | def seed_all(seed_num): 10 | np.random.seed(seed_num) 11 | random.seed(seed_num) 12 | torch.manual_seed(seed_num) 13 | torch.cuda.manual_seed_all(seed_num) 14 | -------------------------------------------------------------------------------- /utils/sparse_combine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from torch_sparse import SparseTensor 8 | 9 | from typing import Optional, Tuple 10 | 11 | from torch_sparse import SparseTensor 12 | from torch_sparse.storage import SparseStorage 13 | 14 | 15 | def spadd(A: SparseTensor, B: SparseTensor, op: str = "add") -> SparseTensor: 16 | assert A.sparse_sizes() == B.sparse_sizes() 17 | 18 | m, n = A.sparse_sizes() 19 | 20 | row = torch.cat([A.storage.row(), B.storage.row()], dim=-1) 21 | col = torch.cat([A.storage.col(), B.storage.col()], dim=-1) 22 | value = torch.cat([A.storage.value(), B.storage.value()], dim=0) 23 | 24 | storage = SparseStorage( 25 | row=row, col=col, value=value, sparse_sizes=(m, n), is_sorted=False 26 | ) 27 | storage = storage.coalesce(reduce=op) 28 | 29 | return SparseTensor.from_storage(storage) 30 | 31 | 32 | ## sparse combine 33 | def sparse_combine( 34 | a: SparseTensor, b: SparseTensor, flag_mult: Optional[Tuple[int, int]] = (1, 2) 35 | ) -> Tuple[SparseTensor, SparseTensor]: 36 | 37 | res = spadd(a, b, op="mean") 38 | 39 | # flag where the source come from 40 | flag = spadd(a.fill_value(flag_mult[0]), b.fill_value(flag_mult[1])) 41 | 42 | return res, flag 43 | 44 | 45 | ## sparse combine 46 | def sparse_combine3( 47 | a: SparseTensor, b: SparseTensor, c: SparseTensor 48 | ) -> Tuple[SparseTensor, SparseTensor, SparseTensor]: 49 | 50 | flag_mult = (1, 2, 4) 51 | res = spadd(spadd(a, b, op="mean"), c, op="mean") 52 | 53 | # flag where the source come from 54 | flag = spadd( 55 | spadd(a.fill_value(flag_mult[0]), b.fill_value(flag_mult[1])), 56 | c.fill_value(flag_mult[2]), 57 | ) 58 | 59 | return res, flag 60 | 61 | 62 | def sparse_combine3a( 63 | a: SparseTensor, b: SparseTensor, c: SparseTensor 64 | ) -> Tuple[SparseTensor, SparseTensor, SparseTensor]: 65 | 66 | flag_mult = (1, 2, 4) 67 | 68 | # add values 69 | d = SparseTensor.from_torch_sparse_coo_tensor( 70 | a.to_torch_sparse_coo_tensor() 71 | + b.to_torch_sparse_coo_tensor() 72 | + c.to_torch_sparse_coo_tensor() 73 | ) 74 | # add non zeros 75 | e = SparseTensor.from_torch_sparse_coo_tensor( 76 | a.fill_value(1).to_torch_sparse_coo_tensor() 77 | + b.fill_value(1).to_torch_sparse_coo_tensor() 78 | + c.fill_value(1).to_torch_sparse_coo_tensor() 79 | ) 80 | 81 | # rmove duplicate values 82 | val = (d.storage.value() / e.storage.value()).long() 83 | res = d.set_value(val, layout="coo") 84 | 85 | # flag where the source come from 86 | flag = SparseTensor.from_torch_sparse_coo_tensor( 87 | a.fill_value(flag_mult[0]).to_torch_sparse_coo_tensor() 88 | + b.fill_value(flag_mult[1]).to_torch_sparse_coo_tensor() 89 | + c.fill_value(flag_mult[2]).to_torch_sparse_coo_tensor() 90 | ) 91 | 92 | return res, flag 93 | 94 | 95 | def xe_split3( 96 | xe: Tensor, 97 | flag: Tensor, 98 | ) -> Tuple[Tensor, Tensor, Tensor]: 99 | 100 | # flag_mult = (1,2,4) 101 | 102 | a_idx = (flag == 1) | (flag == 3) | (flag == 5) | (flag == 7) 103 | b_idx = (flag == 2) | (flag == 3) | (flag == 6) | (flag == 7) 104 | c_idx = (flag == 4) | (flag == 5) | (flag == 6) | (flag == 7) 105 | 106 | xe_a = xe[a_idx] 107 | xe_b = xe[b_idx] 108 | xe_c = xe[c_idx] 109 | 110 | return xe_a, xe_b, xe_c 111 | -------------------------------------------------------------------------------- /utils/sprand.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 
2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | import torch 5 | from torch_sparse import SparseTensor, SparseStorage 6 | from typing import Tuple 7 | 8 | 9 | def sprand(dim: Tuple[int, int], nnz: int) -> SparseTensor: 10 | nu, nv = dim 11 | row = torch.randint(nu, (nnz,)) 12 | col = torch.randint(nv, (nnz,)) 13 | 14 | storage = SparseStorage(row=row, col=col, sparse_sizes=(nu, nv), is_sorted=False) 15 | storage = storage.coalesce(reduce="max") 16 | 17 | return SparseTensor.from_storage(storage) 18 | -------------------------------------------------------------------------------- /utils/sum_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Grabtaxi Holdings Pte Ltd (GRAB), All rights reserved. 2 | # Use of this source code is governed by an MIT-style license that can be found in the LICENSE file 3 | 4 | 5 | def dict_addto(res: dict, a: dict) -> dict: 6 | for k in res.keys(): 7 | res[k] += float(a[k]) 8 | return res 9 | 10 | 11 | def dict_div(res: dict, div) -> dict: 12 | for k in res.keys(): 13 | res[k] /= div 14 | return res 15 | --------------------------------------------------------------------------------
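
The following is a minimal, self-contained sketch (not part of the repository) of how `models/sampler.py` and `models/score.py` fit together: `EdgePredictionSampler.sample()` returns a `SparseTensor` in which observed edges carry the value 1 and randomly drawn negative pairs carry -1, and `edge_prediction_metric` reads that labelling via `edge_pred_samples.storage.value() > 0`. The toy adjacency and the random probabilities below are illustrative placeholders, not data from the experiments.

```
import torch
from torch_sparse import SparseTensor

from models.sampler import EdgePredictionSampler
from models.score import edge_prediction_metric

# Toy 3x4 bipartite adjacency with 5 observed (u, v) interactions.
adj = SparseTensor(
    row=torch.tensor([0, 0, 1, 2, 2]),
    col=torch.tensor([0, 2, 1, 2, 3]),
    sparse_sizes=(3, 4),
)

# Sample negative candidate edges at 2x the number of observed edges.
sampler = EdgePredictionSampler(adj, mult=2)
edge_pred_samples = sampler.sample()

# Observed edges carry value 1, sampled negatives carry value -1.
labels = (edge_pred_samples.storage.value() > 0).int()
print(int(labels.sum()), "positives out of", labels.numel(), "candidate pairs")

# Stand-in for the model's predicted edge probabilities ("eprob" in its output dict).
eprob = torch.rand(edge_pred_samples.nnz())
print(edge_prediction_metric(edge_pred_samples, eprob))
# -> {'acc': ..., 'prec': ..., 'rec': ..., 'f1': ...}
```

In the training scripts this same sampler is re-drawn on every update (once per epoch in `train_full_experiment.py`, once per mini-batch on the sampled subgraph in `train_sample_experiment.py`), so the negative set changes across updates while the positive set stays fixed.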