├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── amazon.yml
│   ├── tfinance.yml
│   ├── tsocial.yml
│   └── yelp.yml
├── data
│   └── dataset source.txt
├── framework.png
├── main.py
├── model-weights
│   ├── amazon.pth
│   ├── tfinance.pth
│   └── yelp.pth
├── models.py
├── modules
│   ├── aux_mod.py
│   ├── conv_mod.py
│   ├── data_loader.py
│   ├── evaluation.py
│   ├── loss.py
│   ├── mod_utls.py
│   └── mr_conv_mod.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.DS_Store
.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023-2024 Xtra Computing Group, NUS, Singapore.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Introduction
This is the code for **ConsisGAD**, our ICLR 2024 paper: [Consistency Training with Learnable Data Augmentation for Graph Anomaly Detection with Limited Supervision](https://openreview.net/forum?id=elMKXvhhQ9).

In this work, we propose a novel framework, ConsisGAD, which is tailored for graph anomaly detection in scenarios characterized by limited supervision and is anchored in the principles of consistency training. Under limited supervision, ConsisGAD effectively leverages the abundance of unlabeled data for consistency training by incorporating a novel learnable data augmentation mechanism, thereby introducing controlled noise into the dataset.
Moreover, ConsisGAD takes advantage of the variance in homophily distribution between normal and anomalous nodes to craft a simplified GNN backbone, enhancing its capability to effectively distinguish between these two classes. A brief overview of our framework is illustrated in the figure below.

![Overall framework of ConsisGAD](framework.png)

*Overall framework of ConsisGAD.*
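
At its core, the training loop pairs a supervised loss on the handful of labeled nodes with a confidence-masked consistency loss on unlabeled nodes: a clean (weak) forward pass produces pseudo-labels for confidently predicted nodes, and the model is then asked to reproduce those labels under the learned augmentation (strong view). The snippet below is a condensed, illustrative sketch of that objective only; the exact logic, including the adversarial update of the augmentor, lives in `UDA_train_epoch` in `main.py`.

```python
import torch
import torch.nn.functional as F

def consistency_loss(weak_logp: torch.Tensor, strong_logp: torch.Tensor,
                     normal_th: float = 0.05, fraud_th: float = 0.88) -> torch.Tensor:
    """Confidence-masked consistency loss (illustrative sketch).

    weak_logp / strong_logp: log-softmax outputs of shape (N, 2) from the clean
    and augmented views of the same unlabeled nodes. The thresholds correspond
    to normal-th/100 and fraud-th/100 in the config files.
    """
    p_fraud = weak_logp.exp()[:, 1]            # fraud probability from the weak view
    pseudo = (p_fraud >= fraud_th).long()      # confident frauds -> 1, everything else -> 0
    mask = ((p_fraud <= normal_th) | (p_fraud >= fraud_th)).float()  # keep confident nodes only
    per_node = F.nll_loss(strong_logp, pseudo, reduction='none')
    return (per_node * mask).mean()
```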

This repository contains the source code for our Graph Neural Network (GNN) backbone, the consistency training procedure, and the learnable data augmentation module. Below is an overview of the key components and their locations within the repository:

- **GNN Backbone Model**: The core implementation of our GNN backbone is encapsulated in the `simpleGNN_MR` class in `models.py`.

- **Consistency Training Procedure**: The consistency training procedure is implemented in the `UDA_train_epoch` function in `main.py`.

- **Learnable Data Augmentation**: Our learnable data augmentation is realized by the `SoftAttentionDrop` class, also located in `main.py`.

## Directory Structure
The repository is organized into several directories, each serving a specific purpose:

- `data/`: This directory houses the datasets utilized in our work.

- `config/`: This folder stores the hyper-parameter configurations of our model, one YAML file per dataset.

- `modules/`: Auxiliary components of our model, including the data loader (`data_loader.py`) and the evaluation pipeline (`evaluation.py`).

- `model-weights/`: Trained weights of our model.

## Installation
- Install the required packages: `pip install -r requirements.txt`
- Dataset resources:
  - For Amazon and YelpChi, we use the built-in datasets of the DGL package: https://docs.dgl.ai/en/0.8.x/api/python/dgl.data.html.
  - For T-Finance and T-Social, we download the datasets from https://github.com/squareRoot3/Rethinking-Anomaly-Detection.
- Please download and unzip all dataset files into the `data/` folder.

## Usage
- Hyper-parameter settings for all datasets are provided in the `config/` folder.
- To run the model, use `--config` to specify the hyper-parameter file and `--runs` to specify the number of runs.
- For example, to run the YelpChi dataset 5 times, execute: `python main.py --config 'config/yelp.yml' --runs 5`.
- Pretrained checkpoints in `model-weights/` can be restored as sketched at the end of this README.

## Citation
If you find our work useful, please cite:

```
@inproceedings{
chen2024consistency,
title={Consistency Training with Learnable Data Augmentation for Graph Anomaly Detection with Limited Supervision},
author={Nan Chen and Zemin Liu and Bryan Hooi and Bingsheng He and Rizal Fathony and Jun Hu and Jia Chen},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=elMKXvhhQ9}
}
```

Feel free to contact nanchansysu@gmail.com if you have any questions.
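
The following is a minimal sketch of how a stored checkpoint could be restored for evaluation, assuming the datasets have been placed under `data/`; it reuses the same builders that `main.py` calls (adjust the config path and device as needed):

```python
import torch
import yaml

from models import simpleGNN_MR
from modules.data_loader import get_index_loader_test

# Load the shipped YelpChi hyper-parameters.
with open('config/yelp.yml', 'r') as f:
    args = yaml.load(f, Loader=yaml.FullLoader)

# Build the graph and data loaders exactly as run_model() does.
graph, _, valid_loader, test_loader, _ = get_index_loader_test(
    name=args['data-set'], batch_size=args['batch-size'],
    unlabel_ratio=args['unlabel-ratio'], training_ratio=args['training-ratio'],
    shuffle_train=args['shuffle-train'], to_homo=args['to-homo'])

# Re-create the backbone and restore the stored weights.
model = simpleGNN_MR(
    in_feats=graph.ndata['feature'].shape[1], hidden_feats=args['hidden-dim'],
    out_feats=2, num_layers=args['num-layers'], e_types=graph.etypes,
    input_drop=args['input-drop'], hidden_drop=args['hidden-drop'],
    mlp_drop=args['mlp-drop'], mlp12_dim=args['mlp12-dim'],
    mlp3_dim=args['mlp3-dim'], bn_type=args['bn-type'])
model.load_state_dict(torch.load('model-weights/yelp.pth', map_location='cpu'))
model.eval()
```

Evaluation can then proceed as in `val_epoch`, using a `dgl.dataloading.MultiLayerFullNeighborSampler(args['num-layers'])` to draw blocks for `valid_loader` and `test_loader`.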
--------------------------------------------------------------------------------
/config/amazon.yml:
--------------------------------------------------------------------------------
data-set: 'amazon'
to-homo: False
shuffle-train: True
model: 'backbone'
hidden-dim: 64
num-layers: 1
epochs: 100
lr: 0.001
weight-decay: 0.00001
device: 1
training-ratio: 1
train-procedure: 'CT'
mlp-drop: 0.3
input-drop: 0.0
hidden-drop: 0.0
mlp12-dim: 128
mlp3-dim: 128
bn-type: 2
optim: 'adam'
store-model: True
trainable-consis-weight: 1.5
trainable-temp: 0.0001
trainable-eps: 0.000000000001
trainable-drop-rate: 0.2
trainable-warm-up: -1
trainable-model: 'proj'
trainable-optim: 'adam'
trainable-lr: 0.01
trainable-weight-decay: 0.0
topk-mode: 4
diversity-type: 'euc'
unlabel-ratio: 6
normal-th: 5
fraud-th: 85
trainable-detach-y: True
trainable-div-eps: True
trainable-detach-mask: False
batch-size: 32
train-iterations: 128
--------------------------------------------------------------------------------
/config/tfinance.yml:
--------------------------------------------------------------------------------
data-set: 'tfinance'
to-homo: False
shuffle-train: True
model: 'backbone'
hidden-dim: 64
num-layers: 1
epochs: 100
lr: 0.001
weight-decay: 0.00001
device: 3
training-ratio: 1
train-procedure: 'CT'
mlp-drop: 0.5
input-drop: 0.0
hidden-drop: 0.0
mlp12-dim: 64
mlp3-dim: 128
bn-type: 2
optim: 'adam'
store-model: True
trainable-consis-weight: 1.0
trainable-temp: 0.0001
trainable-eps: 0.000000000001
trainable-drop-rate: 0.2
trainable-warm-up: -1
trainable-model: 'proj'
trainable-optim: 'adam'
trainable-lr: 0.005
trainable-weight-decay: 0.0
topk-mode: 4
diversity-type: 'euc'
unlabel-ratio: 4
normal-th: 5
fraud-th: 88
trainable-detach-y: True
trainable-div-eps: True
trainable-detach-mask: False
batch-size: 128
train-iterations: 128
--------------------------------------------------------------------------------
/config/tsocial.yml:
--------------------------------------------------------------------------------
data-set: 'tsocial'
to-homo: False
shuffle-train: True
model: 'backbone'
hidden-dim: 64
num-layers: 1
epochs: 100
lr: 0.001
weight-decay: 0.00001
device: 4
training-ratio: 0.01
train-procedure: 'CT'
mlp-drop: 0.4
input-drop: 0.0
hidden-drop: 0.0
mlp12-dim: 128
mlp3-dim: 128
bn-type: 2
optim: 'adam'
store-model: True
trainable-consis-weight: 1.5
trainable-temp: 0.0001
trainable-eps: 0.000000000001
trainable-drop-rate: 0.2
trainable-warm-up: -1
trainable-model: 'proj'
trainable-optim: 'adam'
trainable-lr: 0.005
trainable-weight-decay: 0.0
topk-mode: 4
diversity-type: 'euc'
unlabel-ratio: 5
normal-th: 5
fraud-th: 88
trainable-detach-y: True
trainable-div-eps: True
trainable-detach-mask: False
batch-size: 128
train-iterations: 128
--------------------------------------------------------------------------------
/config/yelp.yml:
--------------------------------------------------------------------------------
data-set: 'yelp'
to-homo: False
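# Note on the keys below (see main.py): `normal-th` and `fraud-th` are percentage
# thresholds for pseudo-labelling unlabeled nodes. A node is pseudo-labelled
# normal if its predicted fraud probability is <= normal-th/100, and fraud if
# it is >= fraud-th/100; all other unlabeled nodes are masked out of the
# consistency loss. `unlabel-ratio` sets the unlabeled batch size to
# batch-size * unlabel-ratio, and the trainable-* keys configure the
# SoftAttentionDrop augmentor (mask network type, temperature, drop rate,
# optimizer, and learning rate).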
shuffle-train: True
model: 'backbone'
hidden-dim: 64
num-layers: 1
epochs: 100
lr: 0.001
weight-decay: 0.00001
device: 2
training-ratio: 1
train-procedure: 'CT'
mlp-drop: 0.4
input-drop: 0.0
hidden-drop: 0.0
mlp12-dim: 128
mlp3-dim: 128
bn-type: 2
optim: 'adam'
store-model: True
trainable-consis-weight: 1.5
trainable-temp: 0.0001
trainable-eps: 0.000000000001
trainable-drop-rate: 0.2
trainable-warm-up: -1
trainable-model: 'mlp'
trainable-optim: 'adam'
trainable-lr: 0.005
trainable-weight-decay: 0.00001
topk-mode: 4
diversity-type: 'cos'
unlabel-ratio: 4
normal-th: 7
fraud-th: 88
trainable-detach-y: True
trainable-div-eps: True
trainable-detach-mask: False
batch-size: 128
train-iterations: 128
--------------------------------------------------------------------------------
/data/dataset source.txt:
--------------------------------------------------------------------------------
For Yelp and Amazon, please refer to https://docs.dgl.ai/en/0.8.x/api/python/dgl.data.html.

For T-Social and T-Finance, please refer to https://github.com/squareRoot3/Rethinking-Anomaly-Detection.
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/ConsisGAD/36811c5bc79be49c9740f25a1f260496bb4736af/framework.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import sys
import os
import csv
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from modules.data_loader import get_index_loader_test
from models import simpleGNN_MR
import modules.mod_utls as m_utls
from modules.loss import nll_loss, l2_regularization, nll_loss_raw
from modules.evaluation import eval_pred
from modules.aux_mod import fixed_augmentation
from sklearn.metrics import f1_score
from modules.conv_mod import CustomLinear
from modules.mr_conv_mod import build_mlp
import numpy as np
from numpy import random
import math
import pandas as pd
from functools import partial
import dgl
import warnings
import wandb
import yaml
warnings.filterwarnings("ignore")


class SoftAttentionDrop(nn.Module):
    def __init__(self, args):
        super(SoftAttentionDrop, self).__init__()
        dim = args['hidden-dim']

        self.temp = args['trainable-temp']
        self.p = args['trainable-drop-rate']
        self.eps = args['trainable-eps']
        if args['trainable-model'] == 'proj':
            self.mask_proj = CustomLinear(dim, dim)
        else:
            self.mask_proj = build_mlp(in_dim=dim, out_dim=dim, p=args['mlp-drop'], final_act=False)

        self.detach_y = args['trainable-detach-y']
        self.div_eps = args['trainable-div-eps']
        self.detach_mask = args['trainable-detach-mask']

    def forward(self, feature, in_eval=False):
        # Score every hidden dimension, then softly select the top-k to drop.
        mask = self.mask_proj(feature)

        y = torch.zeros_like(mask)
        k = round(mask.shape[1] * self.p)

        # One temperature-controlled soft top-1 selection per iteration,
        # masking out dimensions that have already been chosen.
        for _ in range(k):
            if self.detach_y:
                w = torch.zeros_like(y)
                w[y>0.5] = 1
                w = (1. - w).detach()
            else:
                w = (1. - y)

            logw = torch.log(w + 1e-12)
            y1 = (mask + logw) / self.temp
            y1 = y1 - torch.amax(y1, dim=1, keepdim=True)

            if self.div_eps:
                y1 = torch.exp(y1) / (torch.sum(torch.exp(y1), dim=1, keepdim=True) + self.eps)
            else:
                y1 = torch.exp(y1) / torch.sum(torch.exp(y1), dim=1, keepdim=True)

            y = y + y1 * w

        # Invert the selection: chosen dimensions are (softly) zeroed out and
        # the remaining ones are rescaled, as in standard dropout.
        mask = 1. - y
        mask = mask / (1. - self.p)

        if in_eval and self.detach_mask:
            mask = mask.detach()

        return feature * mask


def create_model(args, e_ts):
    if args['model'] == 'backbone':
        tmp_model = simpleGNN_MR(in_feats=args['node-in-dim'], hidden_feats=args['hidden-dim'], out_feats=args['node-out-dim'],
                                 num_layers=args['num-layers'], e_types=e_ts, input_drop=args['input-drop'], hidden_drop=args['hidden-drop'],
                                 mlp_drop=args['mlp-drop'], mlp12_dim=args['mlp12-dim'], mlp3_dim=args['mlp3-dim'], bn_type=args['bn-type'])
    else:
        raise ValueError(f"Unknown model: {args['model']}")
    tmp_model.to(args['device'])

    return tmp_model


def UDA_train_epoch(epoch, model, loss_func, graph, label_loader, unlabel_loader, optimizer, augmentor, args):
    model.train()
    num_iters = args['train-iterations']

    sampler, attn_drop, ad_optim = augmentor

    unlabel_loader_iter = iter(unlabel_loader)
    label_loader_iter = iter(label_loader)

    for idx in range(num_iters):
        # Cycle through both loaders for a fixed number of iterations.
        try:
            label_idx = next(label_loader_iter)
        except StopIteration:
            label_loader_iter = iter(label_loader)
            label_idx = next(label_loader_iter)
        try:
            unlabel_idx = next(unlabel_loader_iter)
        except StopIteration:
            unlabel_loader_iter = iter(unlabel_loader)
            unlabel_idx = next(unlabel_loader_iter)

        if epoch > args['trainable-warm-up']:
            # Weak view: a clean forward pass yields confidence-based pseudo-labels.
            model.eval()
            with torch.no_grad():
                _, _, u_blocks = fixed_augmentation(graph, unlabel_idx.to(args['device']), sampler, aug_type='none')
                weak_inter_results = model(u_blocks, update_bn=False, return_logits=True)
                weak_h = torch.stack(weak_inter_results, dim=1)
                weak_h = weak_h.reshape(weak_h.shape[0], -1)
                weak_logits = model.proj_out(weak_h)
                u_pred_weak_log = weak_logits.log_softmax(dim=-1)
                u_pred_weak = u_pred_weak_log.exp()[:, 1]

            pseudo_labels = torch.ones_like(u_pred_weak).long()
            neg_tar = (u_pred_weak <= (args['normal-th']/100.)).bool()
            pos_tar = (u_pred_weak >= (args['fraud-th']/100.)).bool()
            pseudo_labels[neg_tar] = 0
            pseudo_labels[pos_tar] = 1
            u_mask = torch.logical_or(neg_tar, pos_tar)

            # First, update only the augmentor: it is trained to keep the
            # consistency loss low while pushing the augmented representation
            # away from the weak view (diversity term).
            model.train()
            attn_drop.train()
            for param in model.parameters():
                param.requires_grad = False
            for param in attn_drop.parameters():
                param.requires_grad = True

            _, _, u_blocks = fixed_augmentation(graph, unlabel_idx.to(args['device']), sampler, aug_type='drophidden')

            inter_results = model(u_blocks, update_bn=False, return_logits=True)
            dropped_results = [inter_results[0]]
            for i in range(1, len(inter_results)):
                dropped_results.append(attn_drop(inter_results[i]))
            h = torch.stack(dropped_results, dim=1)
            h = h.reshape(h.shape[0], -1)
            logits = model.proj_out(h)
            u_pred = logits.log_softmax(dim=-1)

            consistency_loss = nll_loss_raw(u_pred, pseudo_labels, pos_w=1.0, reduction='none')
            consistency_loss = torch.mean(consistency_loss * u_mask)

            if args['diversity-type'] == 'cos':
                diversity_loss = F.cosine_similarity(weak_h, h, dim=-1)
            elif args['diversity-type'] == 'euc':
                diversity_loss = F.pairwise_distance(weak_h, h)
            else:
                raise ValueError(f"Unknown diversity type: {args['diversity-type']}")
            diversity_loss = torch.mean(diversity_loss * u_mask)

            total_loss = args['trainable-consis-weight'] * consistency_loss - diversity_loss + args['trainable-weight-decay'] * l2_regularization(attn_drop)

            ad_optim.zero_grad()
            total_loss.backward()
            ad_optim.step()

            # Then freeze the augmentor and train the model itself on the
            # pseudo-labels under the (now fixed) learned augmentation.
            for param in model.parameters():
                param.requires_grad = True
            for param in attn_drop.parameters():
                param.requires_grad = False

            inter_results = model(u_blocks, update_bn=False, return_logits=True)
            dropped_results = [inter_results[0]]
            for i in range(1, len(inter_results)):
                dropped_results.append(attn_drop(inter_results[i], in_eval=True))

            h = torch.stack(dropped_results, dim=1)
            h = h.reshape(h.shape[0], -1)
            logits = model.proj_out(h)
            u_pred = logits.log_softmax(dim=-1)

            unsup_loss = nll_loss_raw(u_pred, pseudo_labels, pos_w=1.0, reduction='none')
            unsup_loss = torch.mean(unsup_loss * u_mask)
        else:
            unsup_loss = 0.0

        _, _, s_blocks = fixed_augmentation(graph, label_idx.to(args['device']), sampler, aug_type='none')
        s_pred = model(s_blocks)
        s_target = s_blocks[-1].dstdata['label']

        sup_loss, _ = loss_func(s_pred, s_target)

        loss = sup_loss + unsup_loss + args['weight-decay'] * l2_regularization(model)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def get_model_pred(model, graph, data_loader, sampler, args):
    model.eval()

    pred_list = []
    target_list = []
    with torch.no_grad():
        for node_idx in data_loader:
            _, _, blocks = sampler.sample_blocks(graph, node_idx.to(args['device']))

            pred = model(blocks)
            target = blocks[-1].dstdata['label']

            pred_list.append(pred.detach())
            target_list.append(target.detach())
        pred_list = torch.cat(pred_list, dim=0)
        target_list = torch.cat(target_list, dim=0)
        pred_list = pred_list.exp()[:, 1]

    return pred_list, target_list


def val_epoch(epoch, model, graph, valid_loader, test_loader, sampler, args):
    valid_dict = {}
    valid_pred, valid_target = get_model_pred(model, graph, valid_loader, sampler, args)
    v_roc, v_pr, _, _, _, _, v_f1, v_thre = eval_pred(valid_pred, valid_target)
    valid_dict['auc-roc'] = v_roc
    valid_dict['auc-pr'] = v_pr
    valid_dict['macro f1'] = v_f1

    test_dict = {}
    test_pred, test_target = get_model_pred(model, graph, test_loader, sampler, args)
    t_roc, t_pr, _, _, _, _, _, _ = eval_pred(test_pred, test_target)
    test_dict['auc-roc'] = t_roc
    test_dict['auc-pr'] = t_pr

    # The decision threshold is selected on the validation set and reused here.
    test_pred = test_pred.cpu().numpy()
    test_target = test_target.cpu().numpy()
    guessed_target = np.zeros_like(test_target)
    guessed_target[test_pred > v_thre] = 1
    t_f1 = f1_score(test_target, guessed_target, average='macro')
    test_dict['macro f1'] = t_f1

    return valid_dict, test_dict


def run_model(args):
    graph, label_loader, valid_loader, test_loader, unlabel_loader = get_index_loader_test(name=args['data-set'],
                                                                                           batch_size=args['batch-size'],
                                                                                           unlabel_ratio=args['unlabel-ratio'],
                                                                                           training_ratio=args['training-ratio'],
                                                                                           shuffle_train=args['shuffle-train'],
                                                                                           to_homo=args['to-homo'])
    graph = graph.to(args['device'])

    args['node-in-dim'] = graph.ndata['feature'].shape[1]
    args['node-out-dim'] = 2

    my_model = create_model(args, graph.etypes)

    if args['optim'] == 'adam':
        optimizer = optim.Adam(my_model.parameters(), lr=args['lr'], weight_decay=0.0)
    elif args['optim'] == 'rmsprop':
        optimizer = optim.RMSprop(my_model.parameters(), lr=args['lr'], weight_decay=0.0)
    else:
        raise ValueError(f"Unknown optimizer: {args['optim']}")

    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args['num-layers'])

    train_epoch = UDA_train_epoch
    attn_drop = SoftAttentionDrop(args).to(args['device'])
    if args['trainable-optim'] == 'rmsprop':
        ad_optim = optim.RMSprop(attn_drop.parameters(), lr=args['trainable-lr'], weight_decay=0.0)
    else:
        ad_optim = optim.Adam(attn_drop.parameters(), lr=args['trainable-lr'], weight_decay=0.0)
    augmentor = (sampler, attn_drop, ad_optim)

    task_loss = nll_loss

    best_val = sys.float_info.min
    test_in_best_val = {}
    for epoch in range(args['epochs']):
        train_epoch(epoch, my_model, task_loss, graph, label_loader, unlabel_loader, optimizer, augmentor, args)
        val_results, test_results = val_epoch(epoch, my_model, graph, valid_loader, test_loader, sampler, args)

        # Model selection: keep the test scores at the best validation AUC-ROC.
        if val_results['auc-roc'] > best_val:
            best_val = val_results['auc-roc']
            test_in_best_val = test_results

            if args['store-model']:
                m_utls.store_model(my_model, args)

    return list(test_in_best_val.values())


def get_config(config_path="config.yml"):
    with open(config_path, "r") as setting:
        config = yaml.load(setting, Loader=yaml.FullLoader)
    return config


if __name__ == '__main__':
    start_time = time.time()

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, type=str, help='Path to the config file.')
    parser.add_argument('--runs', type=int, default=1, help='Number of runs. Default is 1.')
    cfg = vars(parser.parse_args())

    args = get_config(cfg['config'])
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:%d'%(args['device']))
    else:
        args['device'] = torch.device('cpu')

    print(args)
    final_results = []
    for r in range(cfg['runs']):
        final_results.append(run_model(args))

    # Report the mean and standard deviation of the test metrics over all runs.
    final_results = np.array(final_results)
    mean_results = np.mean(final_results, axis=0)
    std_results = np.std(final_results, axis=0)

    print(mean_results)
    print(std_results)
    print('total time: ', time.time()-start_time)
--------------------------------------------------------------------------------
/model-weights/amazon.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/ConsisGAD/36811c5bc79be49c9740f25a1f260496bb4736af/model-weights/amazon.pth
--------------------------------------------------------------------------------
/model-weights/tfinance.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/ConsisGAD/36811c5bc79be49c9740f25a1f260496bb4736af/model-weights/tfinance.pth
--------------------------------------------------------------------------------
/model-weights/yelp.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/ConsisGAD/36811c5bc79be49c9740f25a1f260496bb4736af/model-weights/yelp.pth
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
  1 | from typing import Optional, Tuple, Union
  2 |
  3 | import dgl
  4 | import torch
  5 | import torch.nn as nn
  6 | import torch.nn.functional as F
  7 | import dgl.function as fn
  8 | import modules.mod_utls as m_utls
  9 | import numpy as np
 10 | from modules.conv_mod import CustomLinear
 11 | from modules.mr_conv_mod import build_mlp
 12 |
 13 |
 14 | class CustomBatchNorm1d(nn.BatchNorm1d):
 15 |     def forward(self, input, update_running_stats: bool=True):
 16 |         self.track_running_stats = update_running_stats
 17 |         return super(CustomBatchNorm1d, self).forward(input)
 18 |
 19 |
 20 | class MySimpleConv_MR_test(nn.Module):
 21 |     def __init__(self, in_feats: int, out_feats: int, e_types: list, drop_rate:float=0.0,
 22 |                  mlp3_dim: int=64, bn_type: int=0):
 23 |         super(MySimpleConv_MR_test, self).__init__()
 24 |         self.e_types = e_types
 25 |         self.mlp3_dim = mlp3_dim
 26 |         self.bn_type = bn_type
 27 |         self.multi_relation = len(self.e_types) > 1
 28 |
 29 |         self.proj_edges = nn.ModuleDict()
 30 |         for e_t in self.e_types:
 31 |             self.proj_edges[e_t] = build_mlp(in_feats * 2, out_feats, drop_rate, hid_dim=self.mlp3_dim)
 32 |
 33 |         self.proj_out = CustomLinear(out_feats, out_feats, bias=True)
 34 |         if in_feats != out_feats:
 35 |             self.proj_skip = CustomLinear(in_feats, out_feats, bias=True)
 36 |         else:
 37 |             self.proj_skip = nn.Identity()
 38 |
 39 |         if self.bn_type in [2, 3]:
 40 |             self.edge_bn = nn.ModuleDict()
 41 |             for e_t in self.e_types:
 42 |                 self.edge_bn[e_t] = CustomBatchNorm1d(out_feats)
 43 |
 44 |     def udf_edges(self, e_t: str):
 45 |         assert e_t in self.e_types, 'Invalid edge type!'
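         # Each edge type has its own message MLP (self.proj_edges[e_t]); the closure
         # returned below concatenates the source and destination states of every edge
         # of that type and projects them into a per-edge message stored as 'msg'.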
46 | tmp_fn = self.proj_edges[e_t] 47 | 48 | def fnc(edges): 49 | msg = torch.cat([edges.src['h'], edges.dst['h']], dim=-1) 50 | msg = tmp_fn(msg) 51 | return {'msg': msg} 52 | return fnc 53 | 54 | def forward(self, g, features, update_bn: bool=True): 55 | with g.local_scope(): 56 | src_feats = dst_feats = features 57 | if g.is_block: 58 | dst_feats = src_feats[:g.num_dst_nodes()] 59 | g.srcdata['h'] = src_feats 60 | g.dstdata['h'] = dst_feats 61 | 62 | for e_t in g.etypes: 63 | g.apply_edges(self.udf_edges(e_t), etype=e_t) 64 | 65 | if self.bn_type in [2, 3]: 66 | if not self.multi_relation: 67 | g.edata['msg'] = self.edge_bn[self.e_types[0]](g.edata['msg'], update_running_stats=update_bn) 68 | else: 69 | for e_t in g.canonical_etypes: 70 | g.edata['msg'][e_t] = self.edge_bn[e_t[1]](g.edata['msg'][e_t], update_running_stats=update_bn) 71 | 72 | etype_dict = {} 73 | for e_t in g.etypes: 74 | etype_dict[e_t] = (fn.copy_e('msg', 'msg'), fn.sum('msg', 'out')) 75 | g.multi_update_all(etype_dict=etype_dict, cross_reducer='stack') 76 | 77 | out = g.dstdata.pop('out') 78 | out = torch.sum(out, dim=1) 79 | out = self.proj_out(out) + self.proj_skip(dst_feats) 80 | 81 | return out 82 | 83 | 84 | class simpleGNN_MR(nn.Module): 85 | def __init__(self, in_feats: int, hidden_feats: int, out_feats: int, num_layers: int, e_types: list, 86 | input_drop: float, hidden_drop: float, mlp_drop: float, mlp12_dim: int, 87 | mlp3_dim: int, bn_type: int): 88 | super(simpleGNN_MR, self).__init__() 89 | self.gnn_list = nn.ModuleList() 90 | self.bn_list = nn.ModuleList() 91 | self.num_layers = num_layers 92 | self.input_drop = input_drop 93 | self.hidden_drop = hidden_drop 94 | self.mlp_drop = mlp_drop 95 | self.mlp12_dim = mlp12_dim 96 | self.mlp3_dim = mlp3_dim 97 | self.bn_type = bn_type 98 | 99 | self.proj_in = build_mlp(in_feats, hidden_feats, self.mlp_drop, hid_dim=self.mlp12_dim) 100 | in_feats = hidden_feats 101 | 102 | self.in_bn = None 103 | if self.bn_type in [1, 3]: 104 | self.in_bn = CustomBatchNorm1d(hidden_feats) 105 | 106 | for i in range(num_layers): 107 | in_dim = in_feats if i==0 else hidden_feats 108 | 109 | self.gnn_list.append( 110 | MySimpleConv_MR_test(in_feats=in_dim, out_feats=hidden_feats, 111 | e_types=e_types, drop_rate=self.mlp_drop, 112 | mlp3_dim=self.mlp3_dim, bn_type=self.bn_type)) 113 | 114 | self.bn_list.append(CustomBatchNorm1d(hidden_feats)) 115 | 116 | self.proj_out = build_mlp(hidden_feats*(num_layers+1), out_feats, self.mlp_drop, 117 | hid_dim=self.mlp12_dim, final_act=False) 118 | 119 | self.dropout = nn.Dropout(p=self.hidden_drop) 120 | self.dropout_in = nn.Dropout(p=self.input_drop) 121 | self.activation = F.selu 122 | 123 | def forward(self, blocks: list, update_bn: bool=True, return_logits: bool=False): 124 | final_num = blocks[-1].num_dst_nodes() 125 | h = blocks[0].srcdata['feature'] 126 | h = self.dropout_in(h) 127 | 128 | inter_results = [] 129 | h = self.proj_in(h) 130 | 131 | if self.in_bn is not None: 132 | h = self.in_bn(h, update_running_stats=update_bn) 133 | 134 | inter_results.append(h[:final_num]) 135 | for block, gnn, bn in zip(blocks, self.gnn_list, self.bn_list): 136 | h = gnn(block, h, update_bn) 137 | h = bn(h, update_running_stats=update_bn) 138 | h = self.activation(h) 139 | h = self.dropout(h) 140 | 141 | inter_results.append(h[:final_num]) 142 | 143 | if return_logits: 144 | return inter_results 145 | else: 146 | h = torch.stack(inter_results, dim=1) 147 | h = h.reshape(h.shape[0], -1) 148 | h = self.proj_out(h) 149 | return h.log_softmax(dim=-1) 150 
| 151 | -------------------------------------------------------------------------------- /modules/aux_mod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dgl 3 | import dgl.function as fn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import os 8 | import time 9 | import csv 10 | import math 11 | import modules.mod_utls as m_utls 12 | import random 13 | from modules.conv_mod import CustomLinear 14 | from modules.mr_conv_mod import build_mlp 15 | 16 | 17 | Tensor = torch.tensor 18 | 19 | 20 | def fixed_augmentation(graph, seed_nodes, sampler, aug_type: str, p: float=None): 21 | assert aug_type in ['dropout', 'dropnode', 'dropedge', 'replace', 'drophidden', 'none'] 22 | with graph.local_scope(): 23 | if aug_type == 'dropout': 24 | input_nodes, output_nodes, blocks = sampler.sample_blocks(graph, seed_nodes) 25 | blocks[0].srcdata['feature'] = F.dropout(blocks[0].srcdata['feature'], p) 26 | 27 | elif aug_type == 'dropnode': 28 | input_nodes, output_nodes, blocks = sampler.sample_blocks(graph, seed_nodes) 29 | blocks[0].srcdata['feature'] = m_utls.drop_node(blocks[0].srcdata['feature'], p) 30 | 31 | elif aug_type == 'dropedge': 32 | del_edges = {} 33 | for et in graph.etypes: 34 | _, _, eid = graph.in_edges(seed_nodes, etype=et, form='all') 35 | num_remove = math.floor(eid.shape[0] * p) 36 | del_edges[et] = eid[torch.randperm(eid.shape[0])][:num_remove] 37 | aug_graph = graph 38 | for et in del_edges.keys(): 39 | aug_graph = dgl.remove_edges(aug_graph, del_edges[et], etype=et) 40 | input_nodes, output_nodes, blocks = sampler.sample_blocks(aug_graph, seed_nodes) 41 | 42 | elif aug_type == 'replace': 43 | raise Exception("The Replace sample is not implemented!") 44 | 45 | elif aug_type == 'drophidden': 46 | input_nodes, output_nodes, blocks = sampler.sample_blocks(graph, seed_nodes) 47 | 48 | else: 49 | input_nodes, output_nodes, blocks = sampler.sample_blocks(graph, seed_nodes) 50 | 51 | return input_nodes, output_nodes, blocks 52 | 53 | -------------------------------------------------------------------------------- /modules/conv_mod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dgl 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import dgl.function as fn 7 | from dgl.nn.pytorch import TypedLinear 8 | 9 | 10 | class CustomLinear(nn.Linear): 11 | def reset_parameters(self): 12 | nn.init.xavier_normal_(self.weight) 13 | nn.init.zeros_(self.bias) 14 | -------------------------------------------------------------------------------- /modules/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dgl 3 | import os 4 | import numpy as np 5 | from dgl.data.utils import load_graphs 6 | from torch.utils.data import DataLoader as torch_dataloader 7 | from dgl.dataloading import DataLoader 8 | from sklearn.model_selection import train_test_split 9 | import torch.nn.functional as F 10 | import dgl.function as fn 11 | import logging 12 | import pickle 13 | import os 14 | 15 | 16 | def get_dataset(name: str, raw_dir: str, to_homo: bool=False, random_state: int=717): 17 | if name == 'yelp': 18 | yelp_data = dgl.data.FraudYelpDataset(raw_dir=raw_dir, random_seed=7537, verbose=False) 19 | graph = yelp_data[0] 20 | if to_homo: 21 | graph = dgl.to_homogeneous(graph, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask']) 22 | 
graph = dgl.add_self_loop(graph) 23 | 24 | elif name == 'amazon': 25 | amazon_data = dgl.data.FraudAmazonDataset(raw_dir=raw_dir, random_seed=7537, verbose=False) 26 | graph = amazon_data[0] 27 | if to_homo: 28 | graph = dgl.to_homogeneous(graph, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask']) 29 | graph = dgl.add_self_loop(graph) 30 | 31 | elif name == 'tsocial': 32 | t_social, _ = load_graphs(os.path.join(raw_dir, 'tsocial')) 33 | graph = t_social[0] 34 | graph.ndata['feature'] = graph.ndata['feature'].float() 35 | 36 | elif name == 'tfinance': 37 | t_finance, _ = load_graphs(os.path.join(raw_dir, 'tfinance')) 38 | graph = t_finance[0] 39 | graph.ndata['label'] = graph.ndata['label'].argmax(1) 40 | graph.ndata['feature'] = graph.ndata['feature'].float() 41 | 42 | else: 43 | raise 44 | 45 | return graph 46 | 47 | 48 | def get_index_loader_test(name: str, batch_size: int, unlabel_ratio: int=1, training_ratio: float=-1, 49 | shuffle_train: bool=True, to_homo:bool=False): 50 | assert name in ['yelp', 'amazon', 'tfinance', 'tsocial'], 'Invalid dataset name' 51 | 52 | graph = get_dataset(name, 'data/', to_homo=to_homo, random_state=7537) 53 | 54 | index = np.arange(graph.num_nodes()) 55 | labels = graph.ndata['label'] 56 | if name == 'amazon': 57 | index = np.arange(3305, graph.num_nodes()) 58 | 59 | train_nids, valid_test_nids = train_test_split(index, stratify=labels[index], 60 | train_size=training_ratio/100., random_state=2, shuffle=True) 61 | valid_nids, test_nids = train_test_split(valid_test_nids, stratify=labels[valid_test_nids], 62 | test_size=0.67, random_state=2, shuffle=True) 63 | 64 | train_mask = torch.zeros_like(labels).bool() 65 | val_mask = torch.zeros_like(labels).bool() 66 | test_mask = torch.zeros_like(labels).bool() 67 | 68 | train_mask[train_nids] = 1 69 | val_mask[valid_nids] = 1 70 | test_mask[test_nids] = 1 71 | 72 | graph.ndata['train_mask'] = train_mask 73 | graph.ndata['val_mask'] = val_mask 74 | graph.ndata['test_mask'] = test_mask 75 | 76 | labeled_nids = train_nids 77 | unlabeled_nids = np.concatenate([valid_nids, test_nids, train_nids]) 78 | 79 | power = 10 if name == 'tfinance' else 16 80 | 81 | valid_loader = torch_dataloader(valid_nids, batch_size=2**power, shuffle=False, drop_last=False, num_workers=4) 82 | test_loader = torch_dataloader(test_nids, batch_size=2**power, shuffle=False, drop_last=False, num_workers=4) 83 | labeled_loader = torch_dataloader(labeled_nids, batch_size=batch_size, shuffle=shuffle_train, drop_last=True, num_workers=0) 84 | unlabeled_loader = torch_dataloader(unlabeled_nids, batch_size=batch_size * unlabel_ratio, shuffle=shuffle_train, drop_last=True, num_workers=0) 85 | 86 | return graph, labeled_loader, valid_loader, test_loader, unlabeled_loader 87 | 88 | -------------------------------------------------------------------------------- /modules/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, precision_score, confusion_matrix, average_precision_score 3 | from scikitplot.helpers import binary_ks_curve 4 | import torch 5 | 6 | 7 | Tensor = torch.tensor 8 | 9 | 10 | def eval_auc_roc(pred, target): 11 | scores = roc_auc_score(target, pred) 12 | return scores 13 | 14 | 15 | def eval_auc_pr(pred, target): 16 | scores = average_precision_score(target, pred) 17 | return scores 18 | 19 | 20 | def eval_ks_statistics(target, pred): 21 | scores = binary_ks_curve(target, 
pred)[3] 22 | return scores 23 | 24 | 25 | def find_best_f1(probs, labels): 26 | best_f1, best_thre = -1., -1. 27 | thres_arr = np.linspace(0.05, 0.95, 19) 28 | for thres in thres_arr: 29 | preds = np.zeros_like(labels) 30 | preds[probs > thres] = 1 31 | mf1 = f1_score(labels, preds, average='macro') 32 | if mf1 > best_f1: 33 | best_f1 = mf1 34 | best_thre = thres 35 | return best_f1, best_thre 36 | 37 | 38 | def eval_pred(pred: Tensor, target: Tensor): 39 | s_pred = pred.cpu().detach().numpy() 40 | s_target = target.cpu().detach().numpy() 41 | 42 | auc_roc = roc_auc_score(s_target, s_pred) 43 | auc_pr = average_precision_score(s_target, s_pred) 44 | ks_statistics = eval_ks_statistics(s_target, s_pred) 45 | 46 | best_f1, best_thre = find_best_f1(s_pred, s_target) 47 | p_labels = (s_pred > best_thre).astype(int) 48 | accuracy = np.mean(s_target == p_labels) 49 | recall = recall_score(s_target, p_labels) 50 | precision = precision_score(s_target, p_labels) 51 | 52 | return auc_roc, auc_pr, ks_statistics, accuracy, \ 53 | recall, precision, best_f1, best_thre 54 | 55 | -------------------------------------------------------------------------------- /modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import modules.mod_utls as m_utls 4 | 5 | 6 | Tensor = torch.tensor 7 | 8 | 9 | def nll_loss(pred, target, pos_w: float=1.0): 10 | weight_tensor = torch.tensor([1., pos_w]).to(pred.device) 11 | loss_value = F.nll_loss(pred, target.long(), weight=weight_tensor) 12 | 13 | return loss_value, m_utls.to_np(loss_value) 14 | 15 | 16 | def nll_loss_raw(pred: Tensor, target: Tensor, pos_w, 17 | reduction: str='mean'): 18 | weight_tensor = torch.tensor([1., pos_w]).to(pred.device) 19 | loss_value = F.nll_loss(pred, target.long(), weight=weight_tensor, 20 | reduction=reduction) 21 | 22 | return loss_value 23 | 24 | 25 | def l2_regularization(model): 26 | l2_reg = torch.tensor(0., requires_grad=True) 27 | for key, value in model.named_parameters(): 28 | if len(value.shape) > 1 and 'weight' in key: 29 | l2_reg = l2_reg + torch.sum(value ** 2) * 0.5 30 | return l2_reg 31 | 32 | 33 | -------------------------------------------------------------------------------- /modules/mod_utls.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import dgl 5 | import os 6 | import math 7 | import pickle 8 | from sklearn.metrics import f1_score 9 | 10 | 11 | def to_np(x): 12 | return x.cpu().detach().numpy() 13 | 14 | 15 | def store_model(my_model, args): 16 | file_path = os.path.join('model-weights', 17 | args['data-set'] + '.pth') 18 | torch.save(my_model.state_dict(), file_path) 19 | -------------------------------------------------------------------------------- /modules/mr_conv_mod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dgl 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import dgl.function as fn 7 | from dgl.nn.pytorch import TypedLinear 8 | from modules.conv_mod import CustomLinear 9 | 10 | 11 | def build_mlp(in_dim: int, out_dim: int, p: float, hid_dim: int=64, final_act: bool=True): 12 | mlp_list = [] 13 | 14 | mlp_list.append(CustomLinear(in_dim, hid_dim, bias=True)) 15 | mlp_list.append(nn.ELU()) 16 | mlp_list.append(nn.Dropout(p=p)) 17 | mlp_list.append(nn.LayerNorm(hid_dim)) 18 | 
mlp_list.append(CustomLinear(hid_dim, out_dim, bias=True)) 19 | if final_act: 20 | mlp_list.append(nn.ELU()) 21 | mlp_list.append(nn.Dropout(p=p)) 22 | 23 | return nn.Sequential(*mlp_list) 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | _libgcc_mutex==0.1 2 | _openmp_mutex==5.1 3 | abseil-cpp==20211102.0 4 | absl-py==1.3.0 5 | aiohttp==3.8.3 6 | aiosignal==1.2.0 7 | anyio==3.6.2 8 | argon2-cffi==21.3.0 9 | argon2-cffi-bindings==21.2.0 10 | arrow-cpp==8.0.0 11 | async-timeout==4.0.2 12 | asynctest==0.13.0 13 | attrs==22.2.0 14 | aws-c-common==0.4.57 15 | aws-c-event-stream==0.1.6 16 | aws-checksums==0.1.9 17 | aws-sdk-cpp==1.8.185 18 | babel==2.11.0 19 | backcall==0.2.0 20 | backports==1.0 21 | backports.functools_lru_cache==1.6.4 22 | beautifulsoup4==4.11.1 23 | blas==1.0 24 | bleach==5.0.1 25 | blinker==1.4 26 | boost-cpp==1.73.0 27 | bottleneck==1.3.5 28 | brotli==1.0.9 29 | brotli-bin==1.0.9 30 | brotlipy==0.7.0 31 | bzip2==1.0.8 32 | c-ares==1.19.0 33 | ca-certificates==2023.05.30 34 | cachetools==5.3.0 35 | certifi==2022.12.7 36 | cffi==1.15.1 37 | charset-normalizer==2.0.4 38 | click==8.1.3 39 | cryptography==38.0.1 40 | cuda==11.7.1 41 | cuda-cccl==11.7.91 42 | cuda-command-line-tools==11.7.1 43 | cuda-compiler==11.7.1 44 | cuda-cudart==11.7.99 45 | cuda-cudart-dev==11.7.99 46 | cuda-cuobjdump==11.7.91 47 | cuda-cupti==11.7.101 48 | cuda-cuxxfilt==11.7.91 49 | cuda-demo-suite==12.0.76 50 | cuda-documentation==12.0.76 51 | cuda-driver-dev==11.7.99 52 | cuda-gdb==12.0.90 53 | cuda-libraries==11.7.1 54 | cuda-libraries-dev==11.7.1 55 | cuda-memcheck==11.8.86 56 | cuda-nsight==12.0.78 57 | cuda-nsight-compute==12.0.0 58 | cuda-nvcc==11.7.99 59 | cuda-nvdisasm==12.0.76 60 | cuda-nvml-dev==11.7.91 61 | cuda-nvprof==12.0.90 62 | cuda-nvprune==11.7.91 63 | cuda-nvrtc==11.7.99 64 | cuda-nvrtc-dev==11.7.99 65 | cuda-nvtx==11.7.91 66 | cuda-nvvp==12.0.90 67 | cuda-runtime==11.7.1 68 | cuda-sanitizer-api==12.0.90 69 | cuda-toolkit==11.7.1 70 | cuda-tools==11.7.1 71 | cuda-visual-tools==11.7.1 72 | cycler==0.11.0 73 | dbus==1.13.18 74 | decorator==5.1.1 75 | defusedxml==0.7.1 76 | dgl==1.1.0.cu118 77 | docker-pycreds==0.4.0 78 | entrypoints==0.4 79 | expat==2.4.9 80 | ffmpeg==4.3 81 | fftw==3.3.9 82 | flit-core==3.6.0 83 | fontconfig==2.14.1 84 | fonttools==4.25.0 85 | freetype==2.12.1 86 | frozenlist==1.3.3 87 | gdb==11.2 88 | gds-tools==1.5.0.59 89 | gflags==2.2.2 90 | giflib==5.2.1 91 | gitdb==4.0.10 92 | gitpython==3.1.29 93 | glib==2.69.1 94 | glog==0.5.0 95 | gmp==6.2.1 96 | gmpy2==2.1.2 97 | gnutls==3.6.15 98 | google-api-core==2.11.0 99 | google-api-python-client==2.83.0 100 | google-auth==2.17.1 101 | google-auth-httplib2==0.1.0 102 | google-auth-oauthlib==1.0.0 103 | googleapis-common-protos==1.59.0 104 | grpc-cpp==1.46.1 105 | grpcio==1.42.0 106 | gst-plugins-base==1.14.0 107 | gstreamer==1.14.0 108 | httplib2==0.22.0 109 | icu==58.2 110 | idna==3.4 111 | importlib-metadata==4.11.4 112 | importlib_resources==5.10.1 113 | intel-openmp==2021.4.0 114 | ipykernel==5.5.5 115 | ipython==7.33.0 116 | ipython_genutils==0.2.0 117 | jedi==0.18.2 118 | jinja2==3.1.2 119 | joblib==1.1.1 120 | jpeg==9e 121 | json5==0.9.5 122 | jsonschema==4.17.3 123 | jupyter_client==7.0.6 124 | jupyter_core==4.11.2 125 | jupyter_server==1.23.4 126 | jupyterlab==3.5.2 127 | jupyterlab_pygments==0.2.2 128 | jupyterlab_server==2.17.0 129 | kiwisolver==1.4.4 130 | krb5==1.19.2 
131 | lame==3.100 132 | lcms2==2.12 133 | ld_impl_linux-64==2.38 134 | lerc==3.0 135 | libboost==1.73.0 136 | libbrotlicommon==1.0.9 137 | libbrotlidec==1.0.9 138 | libbrotlienc==1.0.9 139 | libclang==10.0.1 140 | libcublas==11.10.3.66 141 | libcublas-dev==11.10.3.66 142 | libcufft==10.7.2.124 143 | libcufft-dev==10.7.2.124 144 | libcufile==1.5.0.59 145 | libcufile-dev==1.5.0.59 146 | libcurand==10.3.1.50 147 | libcurand-dev==10.3.1.50 148 | libcurl==7.87.0 149 | libcusolver==11.4.0.1 150 | libcusolver-dev==11.4.0.1 151 | libcusparse==11.7.4.91 152 | libcusparse-dev==11.7.4.91 153 | libdeflate==1.8 154 | libedit==3.1.20221030 155 | libev==4.33 156 | libevent==2.1.12 157 | libffi==3.4.2 158 | libgcc-ng==11.2.0 159 | libgfortran-ng==11.2.0 160 | libgfortran5==11.2.0 161 | libgomp==11.2.0 162 | libiconv==1.16 163 | libidn2==2.3.2 164 | libllvm10==10.0.1 165 | libnghttp2==1.52.0 166 | libnpp==11.7.4.75 167 | libnpp-dev==11.7.4.75 168 | libnvjpeg==11.8.0.2 169 | libnvjpeg-dev==11.8.0.2 170 | libpng==1.6.37 171 | libpq==12.9 172 | libprotobuf==3.20.3 173 | libsodium==1.0.18 174 | libssh2==1.10.0 175 | libstdcxx-ng==11.2.0 176 | libtasn1==4.16.0 177 | libthrift==0.15.0 178 | libtiff==4.4.0 179 | libunistring==0.9.10 180 | libuuid==1.41.5 181 | libwebp==1.2.4 182 | libwebp-base==1.2.4 183 | libxcb==1.15 184 | libxkbcommon==1.0.1 185 | libxml2==2.9.14 186 | libxslt==1.1.35 187 | lz4-c==1.9.4 188 | markdown==3.4.1 189 | markupsafe==2.1.1 190 | matplotlib==3.5.2 191 | matplotlib-base==3.5.2 192 | matplotlib-inline==0.1.6 193 | mistune==2.0.4 194 | mkl==2021.4.0 195 | mkl-service==2.4.0 196 | mkl_fft==1.3.1 197 | mkl_random==1.2.2 198 | mpc==1.1.0 199 | mpfr==4.0.2 200 | mpmath==1.2.1 201 | multidict==6.0.2 202 | munkres==1.1.4 203 | nbclassic==0.4.8 204 | nbclient==0.6.8 205 | nbconvert==7.2.7 206 | nbconvert-core==7.2.7 207 | nbconvert-pandoc==7.2.7 208 | nbformat==5.7.1 209 | ncurses==6.3 210 | nest-asyncio==1.5.6 211 | nettle==3.7.3 212 | networkx==2.2 213 | notebook==6.5.2 214 | notebook-shim==0.2.2 215 | nsight-compute==2022.4.0.15 216 | nspr==4.33 217 | nss==3.74 218 | numexpr==2.8.4 219 | numpy==1.21.5 220 | numpy-base==1.21.5 221 | oauthlib==3.2.2 222 | openh264==2.1.1 223 | openssl==1.1.1u 224 | opt-einsum==3.3.0 225 | orc==1.7.4 226 | packaging==22.0 227 | pandas==1.3.5 228 | pandoc==2.19.2 229 | pandocfilters==1.5.0 230 | parso==0.8.3 231 | pathtools==0.1.2 232 | pcre==8.45 233 | pexpect==4.8.0 234 | pickleshare==0.7.5 235 | pillow==9.3.0 236 | pip==22.3.1 237 | pkgutil-resolve-name==1.3.10 238 | ply==3.11 239 | progress==1.5 240 | prometheus_client==0.15.0 241 | promise==2.3 242 | prompt-toolkit==3.0.36 243 | protobuf==4.21.12 244 | psutil==5.9.0 245 | ptyprocess==0.7.0 246 | pyarrow==8.0.0 247 | pyasn1==0.4.8 248 | pyasn1-modules==0.2.8 249 | pycparser==2.21 250 | pyg==2.2.0 251 | pygments==2.13.0 252 | pyjwt==2.4.0 253 | pyopenssl==22.0.0 254 | pyparsing==3.0.9 255 | pyqt==5.15.7 256 | pyqt5-sip==12.11.0 257 | pyro-api==0.1.2 258 | pyro-ppl==1.8.4 259 | pyrsistent==0.18.0 260 | pysocks==1.7.1 261 | python==3.7.15 262 | python-dateutil==2.8.2 263 | python-fastjsonschema==2.16.2 264 | python_abi==3.7 265 | pytorch==1.13.1 266 | pytorch-cluster==1.6.0 267 | pytorch-cuda==11.7 268 | pytorch-mutex==1.0 269 | pytorch-scatter==2.1.0 270 | pytorch-sparse==0.6.16 271 | pytz==2022.7 272 | pyyaml==6.0 273 | pyzmq==19.0.2 274 | qt-main==5.15.2 275 | qt-webengine==5.15.9 276 | qtwebkit==5.212 277 | re2==2022.04.01 278 | readline==8.2 279 | requests==2.28.1 280 | requests-oauthlib==1.3.1 281 | 
rsa==4.9 282 | scikit-learn==1.0.2 283 | scikit-plot==0.3.7 284 | scipy==1.7.3 285 | seaborn==0.12.2 286 | send2trash==1.8.0 287 | sentry-sdk==1.12.1 288 | setproctitle==1.3.2 289 | setuptools==65.5.0 290 | shortuuid==1.0.11 291 | sip==6.6.2 292 | six==1.16.0 293 | smmap==5.0.0 294 | snappy==1.1.9 295 | sniffio==1.3.0 296 | soupsieve==2.3.2.post1 297 | sqlite==3.40.0 298 | sympy==1.10.1 299 | tensorboard==2.6.0 300 | tensorboard-data-server==0.6.1 301 | tensorboard-plugin-wit==1.8.1 302 | tensorboardx==2.2 303 | terminado==0.17.1 304 | threadpoolctl==2.2.0 305 | tinycss2==1.2.1 306 | tk==8.6.12 307 | toml==0.10.2 308 | tomli==2.0.1 309 | torchaudio==0.13.1 310 | torchvision==0.14.1 311 | tornado==6.1 312 | tqdm==4.64.1 313 | traitlets==5.8.0 314 | typing-extensions==4.4.0 315 | typing_extensions==4.4.0 316 | uritemplate==4.1.1 317 | urllib3==1.26.13 318 | utf8proc==2.6.1 319 | wandb==0.13.7 320 | wcwidth==0.2.5 321 | webencodings==0.5.1 322 | websocket-client==1.4.2 323 | werkzeug==2.2.2 324 | wheel==0.37.1 325 | xz==5.2.8 326 | yarl==1.8.1 327 | zeromq==4.3.4 328 | zipp==3.11.0 329 | zlib==1.2.13 330 | zstd==1.5.2 331 | --------------------------------------------------------------------------------