├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── amazon.yml
│   ├── tfinance.yml
│   ├── tsocial.yml
│   └── yelp.yml
├── data
│   └── dataset source.txt
├── framework.png
├── main.py
├── model-weights
│   ├── amazon.pth
│   ├── tfinance.pth
│   └── yelp.pth
├── models.py
├── modules
│   ├── aux_mod.py
│   ├── conv_mod.py
│   ├── data_loader.py
│   ├── evaluation.py
│   ├── loss.py
│   ├── mod_utls.py
│   └── mr_conv_mod.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .DS_Store
161 | .idea/
162 |
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023-2024 Xtra Computing Group, NUS, Singapore.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction:
2 | This is the code for our ICLR 2024 paper **ConsisGAD**: [Consistency Training with Learnable Data Augmentation for Graph Anomaly Detection with Limited Supervision.](https://openreview.net/forum?id=elMKXvhhQ9)
3 |
4 | In this work, we propose a novel framework, ConsisGAD, which is tailored for graph anomaly detection in scenarios characterized by limited supervision and is anchored in the principles of consistency training. Under limited supervision, ConsisGAD effectively leverages the abundance of unlabeled data for consistency training by incorporating a novel learnable data augmentation mechanism, thereby introducing controlled noise into the dataset. Moreover, ConsisGAD takes advantage of the variance in homophily distribution between normal and anomalous nodes to craft a simplified GNN backbone, enhancing its capability to effectively distinguish between these two classes. A brief overview of our framework is illustrated in the following picture.
5 |
6 | ![ConsisGAD framework overview](framework.png)
7 |
10 | This repository contains the source code for our Graph Neural Network (GNN) backbone, consistency training procedure, and learnable data augmentation module. Below is an overview of the key components and their locations within the repository, followed by a short sketch of how they fit together:
11 |
12 | - **GNN Backbone Model**: The core implementation of our GNN backbone model is encapsulated within the `simpleGNN_MR` class located in the `models.py` file.
13 |
14 | - **Consistency Training Procedure**: The consistency training procedure is implemented through the `UDA_train_epoch` function, which can be found in the `main.py` file.
15 |
16 | - **Learnable Data Augmentation**: Our learnable data augmentation is realized via the `SoftAttentionDrop` class, which is also located in the `main.py` file.
17 |
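18 | To make the interplay of these components concrete, here is a minimal sketch of one consistency-training step,
19 | using plain tensors in place of the real DGL blocks; names such as `consistency_step`, `weak_logits`, and
20 | `aug_logits` are illustrative only (see `UDA_train_epoch` in `main.py` for the actual implementation):
21 |
22 | ```python
23 | import torch
24 | import torch.nn.functional as F
25 |
26 | def consistency_step(weak_logits, aug_logits, normal_th=0.05, fraud_th=0.85):
27 |     # Fraud probability from the clean (weak) view, with no augmentation applied.
28 |     probs = weak_logits.log_softmax(dim=-1).exp()[:, 1]
29 |     pseudo = torch.ones_like(probs).long()
30 |     neg, pos = probs <= normal_th, probs >= fraud_th
31 |     pseudo[neg] = 0
32 |     # Only confidently-normal or confidently-fraud nodes contribute to the loss.
33 |     mask = torch.logical_or(neg, pos).float()
34 |     loss = F.nll_loss(aug_logits.log_softmax(dim=-1), pseudo, reduction='none')
35 |     return torch.mean(loss * mask)
36 | ```
37 |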
18 | ## Directory Structure
19 | The repository is organized into several directories, each serving a specific purpose:
20 |
21 | - `data/`: This directory houses the datasets utilized in our work.
22 |
23 | - `config/`: This folder stores the hyper-parameter configuration of our model.
24 |
25 | - `modules/`: Auxiliary components of our model are stored in this directory. It includes important modules, such as the data loader `data_loader.py` and the evaluation pipeline `evaluation.py`.
26 |
27 | - `model-weights/`: Here, we store the trained weights of our model.
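28 |
29 |   These are standard PyTorch `state_dict` checkpoints written by `store_model` in `modules/mod_utls.py`.
30 |   A minimal loading sketch (assuming `args` and `graph` have been prepared as in `run_model` of `main.py`):
31 |
32 |   ```python
33 |   import torch
34 |   from main import create_model
35 |
36 |   model = create_model(args, graph.etypes)  # rebuild the training-time architecture
37 |   state = torch.load('model-weights/yelp.pth', map_location=args['device'])
38 |   model.load_state_dict(state)
39 |   model.eval()
40 |   ```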
28 |
29 | ## Installation:
30 | - Install required packages: `pip install -r requirements.txt`. Note that `requirements.txt` is a full environment export that includes conda-only packages (e.g., the CUDA toolchain and `pytorch`), so installing the core dependencies (PyTorch, DGL, scikit-learn, scikit-plot, PyYAML) manually may be easier.
31 | - Dataset resources:
32 | - For Amazon and YelpChi, we use the built-in datasets in the DGL package https://docs.dgl.ai/en/0.8.x/api/python/dgl.data.html.
33 | - For T-Finance and T-Social, we download the datasets from https://github.com/squareRoot3/Rethinking-Anomaly-Detection.
34 | - Please download and unzip all the files into the `data/` folder (e.g., `data/tfinance` and `data/tsocial`); the DGL datasets are downloaded there automatically on first run.
35 |
36 | ## Usage:
37 | - Hyper-parameter settings for all datasets are provided in the `config/` folder.
38 | - To run the model, use `--config` to specify the hyper-parameter file and `--runs` to set the number of runs.
39 | - For example, to run the YelpChi dataset 5 times, execute: `python main.py --config 'config/yelp.yml' --runs 5`.
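40 |
41 | For reference, the analogous commands for the other bundled configs:
42 |
43 | ```
44 | python main.py --config 'config/amazon.yml' --runs 5
45 | python main.py --config 'config/tfinance.yml' --runs 5
46 | python main.py --config 'config/tsocial.yml' --runs 5
47 | ```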
40 |
41 | ## Citation
42 | If you find our work useful, please cite:
43 |
44 | ```
45 | @inproceedings{
46 | chen2024consistency,
47 | title={Consistency Training with Learnable Data Augmentation for Graph Anomaly Detection with Limited Supervision},
48 | author={Nan Chen and Zemin Liu and Bryan Hooi and Bingsheng He and Rizal Fathony and Jun Hu and Jia Chen},
49 | booktitle={The Twelfth International Conference on Learning Representations},
50 | year={2024},
51 | url={https://openreview.net/forum?id=elMKXvhhQ9}
52 | }
53 | ```
54 |
55 | Feel free to contact nanchansysu@gmail.com if you have any questions.
56 |
--------------------------------------------------------------------------------
/config/amazon.yml:
--------------------------------------------------------------------------------
1 | data-set: 'amazon'
2 | to-homo: False
3 | shuffle-train: True
4 | model: 'backbone'
5 | hidden-dim: 64
6 | num-layers: 1
7 | epochs: 100
8 | lr: 0.001
9 | weight-decay: 0.00001
10 | device: 1
11 | training-ratio: 1
12 | train-procedure: 'CT'
13 | mlp-drop: 0.3
14 | input-drop: 0.0
15 | hidden-drop: 0.0
16 | mlp12-dim: 128
17 | mlp3-dim: 128
18 | bn-type: 2
19 | optim: 'adam'
20 | store-model: True
21 | trainable-consis-weight: 1.5
22 | trainable-temp: 0.0001
23 | trainable-eps: 0.000000000001
24 | trainable-drop-rate: 0.2
25 | trainable-warm-up: -1
26 | trainable-model: 'proj'
27 | trainable-optim: 'adam'
28 | trainable-lr: 0.01
29 | trainable-weight-decay: 0.0
30 | topk-mode: 4
31 | diversity-type: 'euc'
32 | unlabel-ratio: 6
33 | normal-th: 5
34 | fraud-th: 85
35 | trainable-detach-y: True
36 | trainable-div-eps: True
37 | trainable-detach-mask: False
38 | batch-size: 32
39 | train-iterations: 128
40 |
--------------------------------------------------------------------------------
/config/tfinance.yml:
--------------------------------------------------------------------------------
1 | data-set: 'tfinance'
2 | to-homo: False
3 | shuffle-train: True
4 | model: 'backbone'
5 | hidden-dim: 64
6 | num-layers: 1
7 | epochs: 100
8 | lr: 0.001
9 | weight-decay: 0.00001
10 | device: 3
11 | training-ratio: 1
12 | train-procedure: 'CT'
13 | mlp-drop: 0.5
14 | input-drop: 0.0
15 | hidden-drop: 0.0
16 | mlp12-dim: 64
17 | mlp3-dim: 128
18 | bn-type: 2
19 | optim: 'adam'
20 | store-model: True
21 | trainable-consis-weight: 1.0
22 | trainable-temp: 0.0001
23 | trainable-eps: 0.000000000001
24 | trainable-drop-rate: 0.2
25 | trainable-warm-up: -1
26 | trainable-model: 'proj'
27 | trainable-optim: 'adam'
28 | trainable-lr: 0.005
29 | trainable-weight-decay: 0.0
30 | topk-mode: 4
31 | diversity-type: 'euc'
32 | unlabel-ratio: 4
33 | normal-th: 5
34 | fraud-th: 88
35 | trainable-detach-y: True
36 | trainable-div-eps: True
37 | trainable-detach-mask: False
38 | batch-size: 128
39 | train-iterations: 128
40 |
41 |
--------------------------------------------------------------------------------
/config/tsocial.yml:
--------------------------------------------------------------------------------
1 | data-set: 'tsocial'
2 | to-homo: False
3 | shuffle-train: True
4 | model: 'backbone'
5 | hidden-dim: 64
6 | num-layers: 1
7 | epochs: 100
8 | lr: 0.001
9 | weight-decay: 0.00001
10 | device: 4
11 | training-ratio: 0.01
12 | train-procedure: 'CT'
13 | mlp-drop: 0.4
14 | input-drop: 0.0
15 | hidden-drop: 0.0
16 | mlp12-dim: 128
17 | mlp3-dim: 128
18 | bn-type: 2
19 | optim: 'adam'
20 | store-model: True
21 | trainable-consis-weight: 1.5
22 | trainable-temp: 0.0001
23 | trainable-eps: 0.000000000001
24 | trainable-drop-rate: 0.2
25 | trainable-warm-up: -1
26 | trainable-model: 'proj'
27 | trainable-optim: 'adam'
28 | trainable-lr: 0.005
29 | trainable-weight-decay: 0.0
30 | topk-mode: 4
31 | diversity-type: 'euc'
32 | unlabel-ratio: 5
33 | normal-th: 5
34 | fraud-th: 88
35 | trainable-detach-y: True
36 | trainable-div-eps: True
37 | trainable-detach-mask: False
38 | batch-size: 128
39 | train-iterations: 128
40 |
--------------------------------------------------------------------------------
/config/yelp.yml:
--------------------------------------------------------------------------------
1 | data-set: 'yelp'
2 | to-homo: False
3 | shuffle-train: True
4 | model: 'backbone'
5 | hidden-dim: 64
6 | num-layers: 1
7 | epochs: 100
8 | lr: 0.001
9 | weight-decay: 0.00001
10 | device: 2
11 | training-ratio: 1
12 | train-procedure: 'CT'
13 | mlp-drop: 0.4
14 | input-drop: 0.0
15 | hidden-drop: 0.0
16 | mlp12-dim: 128
17 | mlp3-dim: 128
18 | bn-type: 2
19 | optim: 'adam'
20 | store-model: True
21 | trainable-consis-weight: 1.5
22 | trainable-temp: 0.0001
23 | trainable-eps: 0.000000000001
24 | trainable-drop-rate: 0.2
25 | trainable-warm-up: -1
26 | trainable-model: 'mlp'
27 | trainable-optim: 'adam'
28 | trainable-lr: 0.005
29 | trainable-weight-decay: 0.00001
30 | topk-mode: 4
31 | diversity-type: 'cos'
32 | unlabel-ratio: 4
33 | normal-th: 7
34 | fraud-th: 88
35 | trainable-detach-y: True
36 | trainable-div-eps: True
37 | trainable-detach-mask: False
38 | batch-size: 128
39 | train-iterations: 128
40 |
--------------------------------------------------------------------------------
/data/dataset source.txt:
--------------------------------------------------------------------------------
1 | For Yelp and Amazon, please refer to https://docs.dgl.ai/en/0.8.x/api/python/dgl.data.html.
2 |
3 | For T-Social and T-Finance, please refer to https://github.com/squareRoot3/Rethinking-Anomaly-Detection.
4 |
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/ConsisGAD/36811c5bc79be49c9740f25a1f260496bb4736af/framework.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import csv
5 | import time
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from modules.data_loader import get_index_loader_test
11 | from models import simpleGNN_MR
12 | import modules.mod_utls as m_utls
13 | from modules.loss import nll_loss, l2_regularization, nll_loss_raw
14 | from modules.evaluation import eval_pred
15 | from modules.aux_mod import fixed_augmentation
16 | from sklearn.metrics import f1_score
17 | from modules.conv_mod import CustomLinear
18 | from modules.mr_conv_mod import build_mlp
19 | import numpy as np
20 | from numpy import random
21 | import math
22 | import pandas as pd
23 | from functools import partial
24 | import dgl
25 | import warnings
26 | import wandb
27 | import yaml
28 | warnings.filterwarnings("ignore")
29 |
30 |
31 | class SoftAttentionDrop(nn.Module):
32 | def __init__(self, args):
33 | super(SoftAttentionDrop, self).__init__()
34 | dim = args['hidden-dim']
35 |
36 | self.temp = args['trainable-temp']
37 | self.p = args['trainable-drop-rate']
38 | if args['trainable-model'] == 'proj':
39 | self.mask_proj = CustomLinear(dim, dim)
40 | else:
41 | self.mask_proj = build_mlp(in_dim=dim, out_dim=dim, p=args['mlp-drop'], final_act=False)
42 |
43 | self.detach_y = args['trainable-detach-y']
44 | self.div_eps = args['trainable-div-eps']
45 | self.detach_mask = args['trainable-detach-mask']
46 | self.eps = args['trainable-eps']  # cached here so forward() does not rely on the global `args`
47 | def forward(self, feature, in_eval=False):
48 | mask = self.mask_proj(feature)
49 |
50 | y = torch.zeros_like(mask)
51 | k = round(mask.shape[1] * self.p)
52 |
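53 | # Differentiable top-k: over k steps, softly select one more feature dimension to
54 | # drop, via a temperature-scaled softmax over the dimensions not yet selected.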
53 | for _ in range(k):
54 | if self.detach_y:
55 | w = torch.zeros_like(y)
56 | w[y>0.5] = 1
57 | w = (1. - w).detach()
58 | else:
59 | w = (1. - y)
60 |
61 | logw = torch.log(w + 1e-12)
62 | y1 = (mask + logw) / self.temp
63 | y1 = y1 - torch.amax(y1, dim=1, keepdim=True)
64 |
65 | if self.div_eps:
66 | y1 = torch.exp(y1) / (torch.sum(torch.exp(y1), dim=1, keepdim=True) + self.eps)
67 | else:
68 | y1 = torch.exp(y1) / torch.sum(torch.exp(y1), dim=1, keepdim=True)
69 |
70 | y = y + y1 * w
71 |
72 | mask = 1. - y
73 | mask = mask / (1. - self.p)
74 |
75 | if in_eval and self.detach_mask:
76 | mask = mask.detach()
77 |
78 | return feature * mask
79 |
80 |
81 | def create_model(args, e_ts):
82 | if args['model'] == 'backbone':
83 | tmp_model = simpleGNN_MR(in_feats=args['node-in-dim'], hidden_feats=args['hidden-dim'], out_feats=args['node-out-dim'],
84 | num_layers=args['num-layers'], e_types=e_ts, input_drop=args['input-drop'], hidden_drop=args['hidden-drop'],
85 | mlp_drop=args['mlp-drop'], mlp12_dim=args['mlp12-dim'], mlp3_dim=args['mlp3-dim'], bn_type=args['bn-type'])
86 | else:
87 | raise ValueError(f"Unsupported model type: {args['model']}")
88 | tmp_model.to(args['device'])
89 |
90 | return tmp_model
91 |
92 |
93 | def UDA_train_epoch(epoch, model, loss_func, graph, label_loader, unlabel_loader, optimizer, augmentor, args):
94 | model.train()
95 | num_iters = args['train-iterations']
96 |
97 | sampler, attn_drop, ad_optim = augmentor
98 |
99 | unlabel_loader_iter = iter(unlabel_loader)
100 | label_loader_iter = iter(label_loader)
101 |
102 | for idx in range(num_iters):
103 | try:
104 | label_idx = next(label_loader_iter)
105 | except StopIteration:  # labeled loader exhausted; restart it
106 | label_loader_iter = iter(label_loader)
107 | label_idx = next(label_loader_iter)
108 | try:
109 | unlabel_idx = next(unlabel_loader_iter)
110 | except StopIteration:  # unlabeled loader exhausted; restart it
111 | unlabel_loader_iter = iter(unlabel_loader)
112 | unlabel_idx = next(unlabel_loader_iter)
113 |
114 | if epoch > args['trainable-warm-up']:
115 | model.eval()
116 | with torch.no_grad():
117 | _, _, u_blocks = fixed_augmentation(graph, unlabel_idx.to(args['device']), sampler, aug_type='none')
118 | weak_inter_results = model(u_blocks, update_bn=False, return_logits=True)
119 | weak_h = torch.stack(weak_inter_results, dim=1)
120 | weak_h = weak_h.reshape(weak_h.shape[0], -1)
121 | weak_logits = model.proj_out(weak_h)
122 | u_pred_weak_log = weak_logits.log_softmax(dim=-1)
123 | u_pred_weak = u_pred_weak_log.exp()[:, 1]
124 |
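125 | # Pseudo-labels from the clean view: nodes with fraud probability below normal-th%
126 | # are labeled normal, above fraud-th% fraud; everything in between is masked out.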
125 | pseudo_labels = torch.ones_like(u_pred_weak).long()
126 | neg_tar = (u_pred_weak <= (args['normal-th']/100.)).bool()
127 | pos_tar = (u_pred_weak >= (args['fraud-th']/100.)).bool()
128 | pseudo_labels[neg_tar] = 0
129 | pseudo_labels[pos_tar] = 1
130 | u_mask = torch.logical_or(neg_tar, pos_tar)
131 |
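132 | # Phase 1: freeze the backbone and update only the learnable augmentor, which is
133 | # trained to keep predictions consistent with the pseudo-labels while pushing the
134 | # augmented embedding away from the clean one (diversity term).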
132 | model.train()
133 | attn_drop.train()
134 | for param in model.parameters():
135 | param.requires_grad = False
136 | for param in attn_drop.parameters():
137 | param.requires_grad = True
138 |
139 | _, _, u_blocks = fixed_augmentation(graph, unlabel_idx.to(args['device']), sampler, aug_type='drophidden')
140 |
141 | inter_results = model(u_blocks, update_bn=False, return_logits=True)
142 | dropped_results = [inter_results[0]]
143 | for i in range(1, len(inter_results)):
144 | dropped_results.append(attn_drop(inter_results[i]))
145 | h = torch.stack(dropped_results, dim=1)
146 | h = h.reshape(h.shape[0], -1)
147 | logits = model.proj_out(h)
148 | u_pred = logits.log_softmax(dim=-1)
149 |
150 | consistency_loss = nll_loss_raw(u_pred, pseudo_labels, pos_w=1.0, reduction='none')
151 | consistency_loss = torch.mean(consistency_loss * u_mask)
152 |
153 | if args['diversity-type'] == 'cos':
154 | diversity_loss = F.cosine_similarity(weak_h, h, dim=-1)
155 | elif args['diversity-type'] == 'euc':
156 | diversity_loss = F.pairwise_distance(weak_h, h)
157 | else:
158 | raise ValueError(f"Unknown diversity type: {args['diversity-type']}")
159 | diversity_loss = torch.mean(diversity_loss * u_mask)
160 |
161 | total_loss = args['trainable-consis-weight'] * consistency_loss - diversity_loss + args['trainable-weight-decay'] * l2_regularization(attn_drop)
162 |
163 | ad_optim.zero_grad()
164 | total_loss.backward()
165 | ad_optim.step()
166 |
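167 | # Phase 2: freeze the augmentor; the backbone's unsupervised loss on the augmented
168 | # view is computed here and combined with the supervised loss below.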
167 | for param in model.parameters():
168 | param.requires_grad = True
169 | for param in attn_drop.parameters():
170 | param.requires_grad = False
171 |
172 | inter_results = model(u_blocks, update_bn=False, return_logits=True)
173 | dropped_results = [inter_results[0]]
174 | for i in range(1, len(inter_results)):
175 | dropped_results.append(attn_drop(inter_results[i], in_eval=True))
176 |
177 | h = torch.stack(dropped_results, dim=1)
178 | h = h.reshape(h.shape[0], -1)
179 | logits = model.proj_out(h)
180 | u_pred = logits.log_softmax(dim=-1)
181 |
182 | unsup_loss = nll_loss_raw(u_pred, pseudo_labels, pos_w=1.0, reduction='none')
183 | unsup_loss = torch.mean(unsup_loss * u_mask)
184 | else:
185 | unsup_loss = 0.0
186 |
187 | _, _, s_blocks = fixed_augmentation(graph, label_idx.to(args['device']), sampler, aug_type='none')
188 | s_pred = model(s_blocks)
189 | s_target = s_blocks[-1].dstdata['label']
190 |
191 | sup_loss, _ = loss_func(s_pred, s_target)
192 |
193 | loss = sup_loss + unsup_loss + args['weight-decay'] * l2_regularization(model)
194 |
195 | optimizer.zero_grad()
196 | loss.backward()
197 | optimizer.step()
198 |
199 |
200 | def get_model_pred(model, graph, data_loader, sampler, args):
201 | model.eval()
202 |
203 | pred_list = []
204 | target_list = []
205 | with torch.no_grad():
206 | for node_idx in data_loader:
207 | _, _, blocks = sampler.sample_blocks(graph, node_idx.to(args['device']))
208 |
209 | pred = model(blocks)
210 | target = blocks[-1].dstdata['label']
211 |
212 | pred_list.append(pred.detach())
213 | target_list.append(target.detach())
214 | pred_list = torch.cat(pred_list, dim=0)
215 | target_list = torch.cat(target_list, dim=0)
216 | pred_list = pred_list.exp()[:, 1]
217 |
218 | return pred_list, target_list
219 |
220 |
221 | def val_epoch(epoch, model, graph, valid_loader, test_loader, sampler, args):
222 | valid_dict = {}
223 | valid_pred, valid_target = get_model_pred(model, graph, valid_loader, sampler, args)
224 | v_roc, v_pr, _, _, _, _, v_f1, v_thre = eval_pred(valid_pred, valid_target)
225 | valid_dict['auc-roc'] = v_roc
226 | valid_dict['auc-pr'] = v_pr
227 | valid_dict['macro f1'] = v_f1
228 |
229 | test_dict = {}
230 | test_pred, test_target = get_model_pred(model, graph, test_loader, sampler, args)
231 | t_roc, t_pr, _, _, _, _, _, _ = eval_pred(test_pred, test_target)
232 | test_dict['auc-roc'] = t_roc
233 | test_dict['auc-pr'] = t_pr
234 |
235 | test_pred = test_pred.cpu().numpy()
236 | test_target = test_target.cpu().numpy()
237 | guessed_target = np.zeros_like(test_target)
238 | guessed_target[test_pred > v_thre] = 1
239 | t_f1 = f1_score(test_target, guessed_target, average='macro')
240 | test_dict['macro f1'] = t_f1
241 |
242 | return valid_dict, test_dict
243 |
244 |
245 | def run_model(args):
246 | graph, label_loader, valid_loader, test_loader, unlabel_loader = get_index_loader_test(name=args['data-set'],
247 | batch_size=args['batch-size'],
248 | unlabel_ratio=args['unlabel-ratio'],
249 | training_ratio=args['training-ratio'],
250 | shuffle_train=args['shuffle-train'],
251 | to_homo=args['to-homo'])
252 | graph = graph.to(args['device'])
253 |
254 | args['node-in-dim'] = graph.ndata['feature'].shape[1]
255 | args['node-out-dim'] = 2
256 |
257 | my_model = create_model(args, graph.etypes)
258 |
259 | if args['optim'] == 'adam':
260 | optimizer = optim.Adam(my_model.parameters(), lr=args['lr'], weight_decay=0.0)
261 | elif args['optim'] == 'rmsprop':
262 | optimizer = optim.RMSprop(my_model.parameters(), lr=args['lr'], weight_decay=0.0)
263 |
264 | sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args['num-layers'])
265 |
266 | train_epoch = UDA_train_epoch
267 | attn_drop = SoftAttentionDrop(args).to(args['device'])
268 | if args['trainable-optim'] == 'rmsprop':
269 | ad_optim = optim.RMSprop(attn_drop.parameters(), lr=args['trainable-lr'], weight_decay=0.0)
270 | else:
271 | ad_optim = optim.Adam(attn_drop.parameters(), lr=args['trainable-lr'], weight_decay=0.0)
272 | augmentor = (sampler, attn_drop, ad_optim)
273 |
274 | task_loss = nll_loss
275 |
276 | best_val = -1.0  # any valid AUC-ROC will exceed this, so test_in_best_val is always set
277 | for epoch in range(args['epochs']):
278 | train_epoch(epoch, my_model, task_loss, graph, label_loader, unlabel_loader, optimizer, augmentor, args)
279 | val_results, test_results = val_epoch(epoch, my_model, graph, valid_loader, test_loader, sampler, args)
280 |
281 | if val_results['auc-roc'] > best_val:
282 | best_val = val_results['auc-roc']
283 | test_in_best_val = test_results
284 |
285 | if args['store-model']:
286 | m_utls.store_model(my_model, args)
287 |
288 | return list(test_in_best_val.values())
289 |
290 |
291 | def get_config(config_path="config.yml"):
292 | with open(config_path, "r") as setting:
293 | config = yaml.load(setting, Loader=yaml.FullLoader)
294 | return config
295 |
296 |
297 | if __name__ == '__main__':
298 | start_time = time.time()
299 |
300 | parser = argparse.ArgumentParser()
301 | parser.add_argument('--config', required=True, type=str, help='Path to the config file.')
302 | parser.add_argument('--runs', type=int, default=1, help='Number of runs. Default is 1.')
303 | cfg = vars(parser.parse_args())
304 |
305 | args = get_config(cfg['config'])
306 | if torch.cuda.is_available():
307 | args['device'] = torch.device('cuda:%d'%(args['device']))
308 | else:
309 | args['device'] = torch.device('cpu')
310 |
311 | print(args)
312 | final_results = []
313 | for r in range(cfg['runs']):
314 | final_results.append(run_model(args))
315 |
316 | final_results = np.array(final_results)
317 | mean_results = np.mean(final_results, axis=0)
318 | std_results = np.std(final_results, axis=0)
319 |
320 | print(mean_results)
321 | print(std_results)
322 | print('total time: ', time.time()-start_time)
323 |
--------------------------------------------------------------------------------
/model-weights/amazon.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/ConsisGAD/36811c5bc79be49c9740f25a1f260496bb4736af/model-weights/amazon.pth
--------------------------------------------------------------------------------
/model-weights/tfinance.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/ConsisGAD/36811c5bc79be49c9740f25a1f260496bb4736af/model-weights/tfinance.pth
--------------------------------------------------------------------------------
/model-weights/yelp.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xtra-Computing/ConsisGAD/36811c5bc79be49c9740f25a1f260496bb4736af/model-weights/yelp.pth
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import dgl
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import dgl.function as fn
8 | import modules.mod_utls as m_utls
9 | import numpy as np
10 | from modules.conv_mod import CustomLinear
11 | from modules.mr_conv_mod import build_mlp
12 |
13 |
14 | class CustomBatchNorm1d(nn.BatchNorm1d):
15 | def forward(self, input, update_running_stats: bool=True):
16 | self.track_running_stats = update_running_stats
17 | return super(CustomBatchNorm1d, self).forward(input)
18 |
19 |
20 | class MySimpleConv_MR_test(nn.Module):
21 | def __init__(self, in_feats: int, out_feats: int, e_types: list, drop_rate:float=0.0,
22 | mlp3_dim: int=64, bn_type: int=0):
23 | super(MySimpleConv_MR_test, self).__init__()
24 | self.e_types = e_types
25 | self.mlp3_dim = mlp3_dim
26 | self.bn_type = bn_type
27 | self.multi_relation = len(self.e_types) > 1
28 |
29 | self.proj_edges = nn.ModuleDict()
30 | for e_t in self.e_types:
31 | self.proj_edges[e_t] = build_mlp(in_feats * 2, out_feats, drop_rate, hid_dim=self.mlp3_dim)
32 |
33 | self.proj_out = CustomLinear(out_feats, out_feats, bias=True)
34 | if in_feats != out_feats:
35 | self.proj_skip = CustomLinear(in_feats, out_feats, bias=True)
36 | else:
37 | self.proj_skip = nn.Identity()
38 |
39 | if self.bn_type in [2, 3]:
40 | self.edge_bn = nn.ModuleDict()
41 | for e_t in self.e_types:
42 | self.edge_bn[e_t] = CustomBatchNorm1d(out_feats)
43 |
44 | def udf_edges(self, e_t: str):
45 | assert e_t in self.e_types, 'Invalid edge types!'
46 | tmp_fn = self.proj_edges[e_t]
47 |
48 | def fnc(edges):
49 | msg = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
50 | msg = tmp_fn(msg)
51 | return {'msg': msg}
52 | return fnc
53 |
54 | def forward(self, g, features, update_bn: bool=True):
55 | with g.local_scope():
56 | src_feats = dst_feats = features
57 | if g.is_block:
58 | dst_feats = src_feats[:g.num_dst_nodes()]
59 | g.srcdata['h'] = src_feats
60 | g.dstdata['h'] = dst_feats
61 |
62 | for e_t in g.etypes:
63 | g.apply_edges(self.udf_edges(e_t), etype=e_t)
64 |
65 | if self.bn_type in [2, 3]:
66 | if not self.multi_relation:
67 | g.edata['msg'] = self.edge_bn[self.e_types[0]](g.edata['msg'], update_running_stats=update_bn)
68 | else:
69 | for e_t in g.canonical_etypes:
70 | g.edata['msg'][e_t] = self.edge_bn[e_t[1]](g.edata['msg'][e_t], update_running_stats=update_bn)
71 |
72 | etype_dict = {}
73 | for e_t in g.etypes:
74 | etype_dict[e_t] = (fn.copy_e('msg', 'msg'), fn.sum('msg', 'out'))
75 | g.multi_update_all(etype_dict=etype_dict, cross_reducer='stack')
76 |
77 | out = g.dstdata.pop('out')
78 | out = torch.sum(out, dim=1)
79 | out = self.proj_out(out) + self.proj_skip(dst_feats)
80 |
81 | return out
82 |
83 |
84 | class simpleGNN_MR(nn.Module):
85 | def __init__(self, in_feats: int, hidden_feats: int, out_feats: int, num_layers: int, e_types: list,
86 | input_drop: float, hidden_drop: float, mlp_drop: float, mlp12_dim: int,
87 | mlp3_dim: int, bn_type: int):
88 | super(simpleGNN_MR, self).__init__()
89 | self.gnn_list = nn.ModuleList()
90 | self.bn_list = nn.ModuleList()
91 | self.num_layers = num_layers
92 | self.input_drop = input_drop
93 | self.hidden_drop = hidden_drop
94 | self.mlp_drop = mlp_drop
95 | self.mlp12_dim = mlp12_dim
96 | self.mlp3_dim = mlp3_dim
97 | self.bn_type = bn_type
98 |
99 | self.proj_in = build_mlp(in_feats, hidden_feats, self.mlp_drop, hid_dim=self.mlp12_dim)
100 | in_feats = hidden_feats
101 |
102 | self.in_bn = None
103 | if self.bn_type in [1, 3]:
104 | self.in_bn = CustomBatchNorm1d(hidden_feats)
105 |
106 | for i in range(num_layers):
107 | in_dim = in_feats if i==0 else hidden_feats
108 |
109 | self.gnn_list.append(
110 | MySimpleConv_MR_test(in_feats=in_dim, out_feats=hidden_feats,
111 | e_types=e_types, drop_rate=self.mlp_drop,
112 | mlp3_dim=self.mlp3_dim, bn_type=self.bn_type))
113 |
114 | self.bn_list.append(CustomBatchNorm1d(hidden_feats))
115 |
116 | self.proj_out = build_mlp(hidden_feats*(num_layers+1), out_feats, self.mlp_drop,
117 | hid_dim=self.mlp12_dim, final_act=False)
118 |
119 | self.dropout = nn.Dropout(p=self.hidden_drop)
120 | self.dropout_in = nn.Dropout(p=self.input_drop)
121 | self.activation = F.selu
122 |
123 | def forward(self, blocks: list, update_bn: bool=True, return_logits: bool=False):
124 | final_num = blocks[-1].num_dst_nodes()
125 | h = blocks[0].srcdata['feature']
126 | h = self.dropout_in(h)
127 |
128 | inter_results = []
129 | h = self.proj_in(h)
130 |
131 | if self.in_bn is not None:
132 | h = self.in_bn(h, update_running_stats=update_bn)
133 |
134 | inter_results.append(h[:final_num])
135 | for block, gnn, bn in zip(blocks, self.gnn_list, self.bn_list):
136 | h = gnn(block, h, update_bn)
137 | h = bn(h, update_running_stats=update_bn)
138 | h = self.activation(h)
139 | h = self.dropout(h)
140 |
141 | inter_results.append(h[:final_num])
142 |
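143 | # With return_logits=True, return the per-layer dst-node representations so the
144 | # caller (e.g. UDA_train_epoch) can perturb them before the final projection.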
143 | if return_logits:
144 | return inter_results
145 | else:
146 | h = torch.stack(inter_results, dim=1)
147 | h = h.reshape(h.shape[0], -1)
148 | h = self.proj_out(h)
149 | return h.log_softmax(dim=-1)
150 |
151 |
--------------------------------------------------------------------------------
/modules/aux_mod.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dgl
3 | import dgl.function as fn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import os
8 | import time
9 | import csv
10 | import math
11 | import modules.mod_utls as m_utls
12 | import random
13 | from modules.conv_mod import CustomLinear
14 | from modules.mr_conv_mod import build_mlp
15 |
16 |
17 | Tensor = torch.Tensor  # the tensor class (not the torch.tensor factory), for type annotations
18 |
19 |
20 | def fixed_augmentation(graph, seed_nodes, sampler, aug_type: str, p: float=None):
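21 | # Sample message-passing blocks for `seed_nodes`, optionally applying an input-level
22 | # augmentation; 'drophidden' samples plain blocks (hidden-layer perturbation is applied
23 | # later by the caller), and 'none' applies no augmentation at all.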
21 | assert aug_type in ['dropout', 'dropnode', 'dropedge', 'replace', 'drophidden', 'none']
22 | with graph.local_scope():
23 | if aug_type == 'dropout':
24 | input_nodes, output_nodes, blocks = sampler.sample_blocks(graph, seed_nodes)
25 | blocks[0].srcdata['feature'] = F.dropout(blocks[0].srcdata['feature'], p)
26 |
27 | elif aug_type == 'dropnode':
28 | input_nodes, output_nodes, blocks = sampler.sample_blocks(graph, seed_nodes)
29 | blocks[0].srcdata['feature'] = m_utls.drop_node(blocks[0].srcdata['feature'], p)
30 |
31 | elif aug_type == 'dropedge':
32 | del_edges = {}
33 | for et in graph.etypes:
34 | _, _, eid = graph.in_edges(seed_nodes, etype=et, form='all')
35 | num_remove = math.floor(eid.shape[0] * p)
36 | del_edges[et] = eid[torch.randperm(eid.shape[0])][:num_remove]
37 | aug_graph = graph
38 | for et in del_edges.keys():
39 | aug_graph = dgl.remove_edges(aug_graph, del_edges[et], etype=et)
40 | input_nodes, output_nodes, blocks = sampler.sample_blocks(aug_graph, seed_nodes)
41 |
42 | elif aug_type == 'replace':
43 | raise NotImplementedError("The 'replace' augmentation is not implemented!")
44 |
45 | elif aug_type == 'drophidden':
46 | input_nodes, output_nodes, blocks = sampler.sample_blocks(graph, seed_nodes)
47 |
48 | else:
49 | input_nodes, output_nodes, blocks = sampler.sample_blocks(graph, seed_nodes)
50 |
51 | return input_nodes, output_nodes, blocks
52 |
53 |
--------------------------------------------------------------------------------
/modules/conv_mod.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dgl
3 | import math
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import dgl.function as fn
7 | from dgl.nn.pytorch import TypedLinear
8 |
9 |
10 | class CustomLinear(nn.Linear):
11 | def reset_parameters(self):
12 | nn.init.xavier_normal_(self.weight)
13 | if self.bias is not None:  # guard against bias=False construction
14 | nn.init.zeros_(self.bias)
--------------------------------------------------------------------------------
/modules/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dgl
3 | import os
4 | import numpy as np
5 | from dgl.data.utils import load_graphs
6 | from torch.utils.data import DataLoader as torch_dataloader
7 | from dgl.dataloading import DataLoader
8 | from sklearn.model_selection import train_test_split
9 | import torch.nn.functional as F
10 | import dgl.function as fn
11 | import logging
12 | import pickle
14 |
15 |
16 | def get_dataset(name: str, raw_dir: str, to_homo: bool=False, random_state: int=717):
17 | if name == 'yelp':
18 | yelp_data = dgl.data.FraudYelpDataset(raw_dir=raw_dir, random_seed=7537, verbose=False)
19 | graph = yelp_data[0]
20 | if to_homo:
21 | graph = dgl.to_homogeneous(graph, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask'])
22 | graph = dgl.add_self_loop(graph)
23 |
24 | elif name == 'amazon':
25 | amazon_data = dgl.data.FraudAmazonDataset(raw_dir=raw_dir, random_seed=7537, verbose=False)
26 | graph = amazon_data[0]
27 | if to_homo:
28 | graph = dgl.to_homogeneous(graph, ndata=['feature', 'label', 'train_mask', 'val_mask', 'test_mask'])
29 | graph = dgl.add_self_loop(graph)
30 |
31 | elif name == 'tsocial':
32 | t_social, _ = load_graphs(os.path.join(raw_dir, 'tsocial'))
33 | graph = t_social[0]
34 | graph.ndata['feature'] = graph.ndata['feature'].float()
35 |
36 | elif name == 'tfinance':
37 | t_finance, _ = load_graphs(os.path.join(raw_dir, 'tfinance'))
38 | graph = t_finance[0]
39 | graph.ndata['label'] = graph.ndata['label'].argmax(1)
40 | graph.ndata['feature'] = graph.ndata['feature'].float()
41 |
42 | else:
43 | raise ValueError(f'Unknown dataset: {name}')
44 |
45 | return graph
46 |
47 |
48 | def get_index_loader_test(name: str, batch_size: int, unlabel_ratio: int=1, training_ratio: float=-1,
49 | shuffle_train: bool=True, to_homo:bool=False):
50 | assert name in ['yelp', 'amazon', 'tfinance', 'tsocial'], 'Invalid dataset name'
51 |
52 | graph = get_dataset(name, 'data/', to_homo=to_homo, random_state=7537)
53 |
54 | index = np.arange(graph.num_nodes())
55 | labels = graph.ndata['label']
56 | if name == 'amazon':
57 | index = np.arange(3305, graph.num_nodes())
58 |
59 | train_nids, valid_test_nids = train_test_split(index, stratify=labels[index],
60 | train_size=training_ratio/100., random_state=2, shuffle=True)
61 | valid_nids, test_nids = train_test_split(valid_test_nids, stratify=labels[valid_test_nids],
62 | test_size=0.67, random_state=2, shuffle=True)
63 |
64 | train_mask = torch.zeros_like(labels).bool()
65 | val_mask = torch.zeros_like(labels).bool()
66 | test_mask = torch.zeros_like(labels).bool()
67 |
68 | train_mask[train_nids] = 1
69 | val_mask[valid_nids] = 1
70 | test_mask[test_nids] = 1
71 |
72 | graph.ndata['train_mask'] = train_mask
73 | graph.ndata['val_mask'] = val_mask
74 | graph.ndata['test_mask'] = test_mask
75 |
76 | labeled_nids = train_nids
77 | unlabeled_nids = np.concatenate([valid_nids, test_nids, train_nids])
78 |
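79 | # Per-dataset evaluation batch size: 2**10 for tfinance, 2**16 for the others
80 | # (presumably tuned to fit evaluation batches in memory).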
79 | power = 10 if name == 'tfinance' else 16
80 |
81 | valid_loader = torch_dataloader(valid_nids, batch_size=2**power, shuffle=False, drop_last=False, num_workers=4)
82 | test_loader = torch_dataloader(test_nids, batch_size=2**power, shuffle=False, drop_last=False, num_workers=4)
83 | labeled_loader = torch_dataloader(labeled_nids, batch_size=batch_size, shuffle=shuffle_train, drop_last=True, num_workers=0)
84 | unlabeled_loader = torch_dataloader(unlabeled_nids, batch_size=batch_size * unlabel_ratio, shuffle=shuffle_train, drop_last=True, num_workers=0)
85 |
86 | return graph, labeled_loader, valid_loader, test_loader, unlabeled_loader
87 |
88 |
--------------------------------------------------------------------------------
/modules/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, precision_score, confusion_matrix, average_precision_score
3 | from scikitplot.helpers import binary_ks_curve
4 | import torch
5 |
6 |
7 | Tensor = torch.Tensor  # tensor class alias for type annotations
8 |
9 |
10 | def eval_auc_roc(pred, target):
11 | scores = roc_auc_score(target, pred)
12 | return scores
13 |
14 |
15 | def eval_auc_pr(pred, target):
16 | scores = average_precision_score(target, pred)
17 | return scores
18 |
19 |
20 | def eval_ks_statistics(target, pred):
21 | scores = binary_ks_curve(target, pred)[3]
22 | return scores
23 |
24 |
25 | def find_best_f1(probs, labels):
26 | best_f1, best_thre = -1., -1.
27 | thres_arr = np.linspace(0.05, 0.95, 19)
28 | for thres in thres_arr:
29 | preds = np.zeros_like(labels)
30 | preds[probs > thres] = 1
31 | mf1 = f1_score(labels, preds, average='macro')
32 | if mf1 > best_f1:
33 | best_f1 = mf1
34 | best_thre = thres
35 | return best_f1, best_thre
36 |
37 |
38 | def eval_pred(pred: Tensor, target: Tensor):
39 | s_pred = pred.cpu().detach().numpy()
40 | s_target = target.cpu().detach().numpy()
41 |
42 | auc_roc = roc_auc_score(s_target, s_pred)
43 | auc_pr = average_precision_score(s_target, s_pred)
44 | ks_statistics = eval_ks_statistics(s_target, s_pred)
45 |
46 | best_f1, best_thre = find_best_f1(s_pred, s_target)
47 | p_labels = (s_pred > best_thre).astype(int)
48 | accuracy = np.mean(s_target == p_labels)
49 | recall = recall_score(s_target, p_labels)
50 | precision = precision_score(s_target, p_labels)
51 |
52 | return auc_roc, auc_pr, ks_statistics, accuracy, \
53 | recall, precision, best_f1, best_thre
54 |
55 |
--------------------------------------------------------------------------------
/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import modules.mod_utls as m_utls
4 |
5 |
6 | Tensor = torch.Tensor  # tensor class alias for type annotations
7 |
8 |
9 | def nll_loss(pred, target, pos_w: float=1.0):
10 | weight_tensor = torch.tensor([1., pos_w]).to(pred.device)
11 | loss_value = F.nll_loss(pred, target.long(), weight=weight_tensor)
12 |
13 | return loss_value, m_utls.to_np(loss_value)
14 |
15 |
16 | def nll_loss_raw(pred: Tensor, target: Tensor, pos_w: float,
17 | reduction: str='mean'):
18 | weight_tensor = torch.tensor([1., pos_w]).to(pred.device)
19 | loss_value = F.nll_loss(pred, target.long(), weight=weight_tensor,
20 | reduction=reduction)
21 |
22 | return loss_value
23 |
24 |
25 | def l2_regularization(model):
26 | l2_reg = torch.tensor(0., requires_grad=True)
27 | for key, value in model.named_parameters():
28 | if len(value.shape) > 1 and 'weight' in key:
29 | l2_reg = l2_reg + torch.sum(value ** 2) * 0.5
30 | return l2_reg
31 |
32 |
33 |
--------------------------------------------------------------------------------
/modules/mod_utls.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import dgl
5 | import os
6 | import math
7 | import pickle
8 | from sklearn.metrics import f1_score
9 |
10 |
11 | def to_np(x):
12 | return x.cpu().detach().numpy()
13 |
14 |
15 | def store_model(my_model, args):
16 | file_path = os.path.join('model-weights',
17 | args['data-set'] + '.pth')
18 | torch.save(my_model.state_dict(), file_path)
19 |
--------------------------------------------------------------------------------
/modules/mr_conv_mod.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dgl
3 | import math
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import dgl.function as fn
7 | from dgl.nn.pytorch import TypedLinear
8 | from modules.conv_mod import CustomLinear
9 |
10 |
11 | def build_mlp(in_dim: int, out_dim: int, p: float, hid_dim: int=64, final_act: bool=True):
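12 | # Two-layer MLP: Linear -> ELU -> Dropout -> LayerNorm -> Linear, optionally
13 | # followed by a final ELU + Dropout when final_act is True.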
12 | mlp_list = []
13 |
14 | mlp_list.append(CustomLinear(in_dim, hid_dim, bias=True))
15 | mlp_list.append(nn.ELU())
16 | mlp_list.append(nn.Dropout(p=p))
17 | mlp_list.append(nn.LayerNorm(hid_dim))
18 | mlp_list.append(CustomLinear(hid_dim, out_dim, bias=True))
19 | if final_act:
20 | mlp_list.append(nn.ELU())
21 | mlp_list.append(nn.Dropout(p=p))
22 |
23 | return nn.Sequential(*mlp_list)
24 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | _libgcc_mutex==0.1
2 | _openmp_mutex==5.1
3 | abseil-cpp==20211102.0
4 | absl-py==1.3.0
5 | aiohttp==3.8.3
6 | aiosignal==1.2.0
7 | anyio==3.6.2
8 | argon2-cffi==21.3.0
9 | argon2-cffi-bindings==21.2.0
10 | arrow-cpp==8.0.0
11 | async-timeout==4.0.2
12 | asynctest==0.13.0
13 | attrs==22.2.0
14 | aws-c-common==0.4.57
15 | aws-c-event-stream==0.1.6
16 | aws-checksums==0.1.9
17 | aws-sdk-cpp==1.8.185
18 | babel==2.11.0
19 | backcall==0.2.0
20 | backports==1.0
21 | backports.functools_lru_cache==1.6.4
22 | beautifulsoup4==4.11.1
23 | blas==1.0
24 | bleach==5.0.1
25 | blinker==1.4
26 | boost-cpp==1.73.0
27 | bottleneck==1.3.5
28 | brotli==1.0.9
29 | brotli-bin==1.0.9
30 | brotlipy==0.7.0
31 | bzip2==1.0.8
32 | c-ares==1.19.0
33 | ca-certificates==2023.05.30
34 | cachetools==5.3.0
35 | certifi==2022.12.7
36 | cffi==1.15.1
37 | charset-normalizer==2.0.4
38 | click==8.1.3
39 | cryptography==38.0.1
40 | cuda==11.7.1
41 | cuda-cccl==11.7.91
42 | cuda-command-line-tools==11.7.1
43 | cuda-compiler==11.7.1
44 | cuda-cudart==11.7.99
45 | cuda-cudart-dev==11.7.99
46 | cuda-cuobjdump==11.7.91
47 | cuda-cupti==11.7.101
48 | cuda-cuxxfilt==11.7.91
49 | cuda-demo-suite==12.0.76
50 | cuda-documentation==12.0.76
51 | cuda-driver-dev==11.7.99
52 | cuda-gdb==12.0.90
53 | cuda-libraries==11.7.1
54 | cuda-libraries-dev==11.7.1
55 | cuda-memcheck==11.8.86
56 | cuda-nsight==12.0.78
57 | cuda-nsight-compute==12.0.0
58 | cuda-nvcc==11.7.99
59 | cuda-nvdisasm==12.0.76
60 | cuda-nvml-dev==11.7.91
61 | cuda-nvprof==12.0.90
62 | cuda-nvprune==11.7.91
63 | cuda-nvrtc==11.7.99
64 | cuda-nvrtc-dev==11.7.99
65 | cuda-nvtx==11.7.91
66 | cuda-nvvp==12.0.90
67 | cuda-runtime==11.7.1
68 | cuda-sanitizer-api==12.0.90
69 | cuda-toolkit==11.7.1
70 | cuda-tools==11.7.1
71 | cuda-visual-tools==11.7.1
72 | cycler==0.11.0
73 | dbus==1.13.18
74 | decorator==5.1.1
75 | defusedxml==0.7.1
76 | dgl==1.1.0.cu118
77 | docker-pycreds==0.4.0
78 | entrypoints==0.4
79 | expat==2.4.9
80 | ffmpeg==4.3
81 | fftw==3.3.9
82 | flit-core==3.6.0
83 | fontconfig==2.14.1
84 | fonttools==4.25.0
85 | freetype==2.12.1
86 | frozenlist==1.3.3
87 | gdb==11.2
88 | gds-tools==1.5.0.59
89 | gflags==2.2.2
90 | giflib==5.2.1
91 | gitdb==4.0.10
92 | gitpython==3.1.29
93 | glib==2.69.1
94 | glog==0.5.0
95 | gmp==6.2.1
96 | gmpy2==2.1.2
97 | gnutls==3.6.15
98 | google-api-core==2.11.0
99 | google-api-python-client==2.83.0
100 | google-auth==2.17.1
101 | google-auth-httplib2==0.1.0
102 | google-auth-oauthlib==1.0.0
103 | googleapis-common-protos==1.59.0
104 | grpc-cpp==1.46.1
105 | grpcio==1.42.0
106 | gst-plugins-base==1.14.0
107 | gstreamer==1.14.0
108 | httplib2==0.22.0
109 | icu==58.2
110 | idna==3.4
111 | importlib-metadata==4.11.4
112 | importlib_resources==5.10.1
113 | intel-openmp==2021.4.0
114 | ipykernel==5.5.5
115 | ipython==7.33.0
116 | ipython_genutils==0.2.0
117 | jedi==0.18.2
118 | jinja2==3.1.2
119 | joblib==1.1.1
120 | jpeg==9e
121 | json5==0.9.5
122 | jsonschema==4.17.3
123 | jupyter_client==7.0.6
124 | jupyter_core==4.11.2
125 | jupyter_server==1.23.4
126 | jupyterlab==3.5.2
127 | jupyterlab_pygments==0.2.2
128 | jupyterlab_server==2.17.0
129 | kiwisolver==1.4.4
130 | krb5==1.19.2
131 | lame==3.100
132 | lcms2==2.12
133 | ld_impl_linux-64==2.38
134 | lerc==3.0
135 | libboost==1.73.0
136 | libbrotlicommon==1.0.9
137 | libbrotlidec==1.0.9
138 | libbrotlienc==1.0.9
139 | libclang==10.0.1
140 | libcublas==11.10.3.66
141 | libcublas-dev==11.10.3.66
142 | libcufft==10.7.2.124
143 | libcufft-dev==10.7.2.124
144 | libcufile==1.5.0.59
145 | libcufile-dev==1.5.0.59
146 | libcurand==10.3.1.50
147 | libcurand-dev==10.3.1.50
148 | libcurl==7.87.0
149 | libcusolver==11.4.0.1
150 | libcusolver-dev==11.4.0.1
151 | libcusparse==11.7.4.91
152 | libcusparse-dev==11.7.4.91
153 | libdeflate==1.8
154 | libedit==3.1.20221030
155 | libev==4.33
156 | libevent==2.1.12
157 | libffi==3.4.2
158 | libgcc-ng==11.2.0
159 | libgfortran-ng==11.2.0
160 | libgfortran5==11.2.0
161 | libgomp==11.2.0
162 | libiconv==1.16
163 | libidn2==2.3.2
164 | libllvm10==10.0.1
165 | libnghttp2==1.52.0
166 | libnpp==11.7.4.75
167 | libnpp-dev==11.7.4.75
168 | libnvjpeg==11.8.0.2
169 | libnvjpeg-dev==11.8.0.2
170 | libpng==1.6.37
171 | libpq==12.9
172 | libprotobuf==3.20.3
173 | libsodium==1.0.18
174 | libssh2==1.10.0
175 | libstdcxx-ng==11.2.0
176 | libtasn1==4.16.0
177 | libthrift==0.15.0
178 | libtiff==4.4.0
179 | libunistring==0.9.10
180 | libuuid==1.41.5
181 | libwebp==1.2.4
182 | libwebp-base==1.2.4
183 | libxcb==1.15
184 | libxkbcommon==1.0.1
185 | libxml2==2.9.14
186 | libxslt==1.1.35
187 | lz4-c==1.9.4
188 | markdown==3.4.1
189 | markupsafe==2.1.1
190 | matplotlib==3.5.2
191 | matplotlib-base==3.5.2
192 | matplotlib-inline==0.1.6
193 | mistune==2.0.4
194 | mkl==2021.4.0
195 | mkl-service==2.4.0
196 | mkl_fft==1.3.1
197 | mkl_random==1.2.2
198 | mpc==1.1.0
199 | mpfr==4.0.2
200 | mpmath==1.2.1
201 | multidict==6.0.2
202 | munkres==1.1.4
203 | nbclassic==0.4.8
204 | nbclient==0.6.8
205 | nbconvert==7.2.7
206 | nbconvert-core==7.2.7
207 | nbconvert-pandoc==7.2.7
208 | nbformat==5.7.1
209 | ncurses==6.3
210 | nest-asyncio==1.5.6
211 | nettle==3.7.3
212 | networkx==2.2
213 | notebook==6.5.2
214 | notebook-shim==0.2.2
215 | nsight-compute==2022.4.0.15
216 | nspr==4.33
217 | nss==3.74
218 | numexpr==2.8.4
219 | numpy==1.21.5
220 | numpy-base==1.21.5
221 | oauthlib==3.2.2
222 | openh264==2.1.1
223 | openssl==1.1.1u
224 | opt-einsum==3.3.0
225 | orc==1.7.4
226 | packaging==22.0
227 | pandas==1.3.5
228 | pandoc==2.19.2
229 | pandocfilters==1.5.0
230 | parso==0.8.3
231 | pathtools==0.1.2
232 | pcre==8.45
233 | pexpect==4.8.0
234 | pickleshare==0.7.5
235 | pillow==9.3.0
236 | pip==22.3.1
237 | pkgutil-resolve-name==1.3.10
238 | ply==3.11
239 | progress==1.5
240 | prometheus_client==0.15.0
241 | promise==2.3
242 | prompt-toolkit==3.0.36
243 | protobuf==4.21.12
244 | psutil==5.9.0
245 | ptyprocess==0.7.0
246 | pyarrow==8.0.0
247 | pyasn1==0.4.8
248 | pyasn1-modules==0.2.8
249 | pycparser==2.21
250 | pyg==2.2.0
251 | pygments==2.13.0
252 | pyjwt==2.4.0
253 | pyopenssl==22.0.0
254 | pyparsing==3.0.9
255 | pyqt==5.15.7
256 | pyqt5-sip==12.11.0
257 | pyro-api==0.1.2
258 | pyro-ppl==1.8.4
259 | pyrsistent==0.18.0
260 | pysocks==1.7.1
261 | python==3.7.15
262 | python-dateutil==2.8.2
263 | python-fastjsonschema==2.16.2
264 | python_abi==3.7
265 | pytorch==1.13.1
266 | pytorch-cluster==1.6.0
267 | pytorch-cuda==11.7
268 | pytorch-mutex==1.0
269 | pytorch-scatter==2.1.0
270 | pytorch-sparse==0.6.16
271 | pytz==2022.7
272 | pyyaml==6.0
273 | pyzmq==19.0.2
274 | qt-main==5.15.2
275 | qt-webengine==5.15.9
276 | qtwebkit==5.212
277 | re2==2022.04.01
278 | readline==8.2
279 | requests==2.28.1
280 | requests-oauthlib==1.3.1
281 | rsa==4.9
282 | scikit-learn==1.0.2
283 | scikit-plot==0.3.7
284 | scipy==1.7.3
285 | seaborn==0.12.2
286 | send2trash==1.8.0
287 | sentry-sdk==1.12.1
288 | setproctitle==1.3.2
289 | setuptools==65.5.0
290 | shortuuid==1.0.11
291 | sip==6.6.2
292 | six==1.16.0
293 | smmap==5.0.0
294 | snappy==1.1.9
295 | sniffio==1.3.0
296 | soupsieve==2.3.2.post1
297 | sqlite==3.40.0
298 | sympy==1.10.1
299 | tensorboard==2.6.0
300 | tensorboard-data-server==0.6.1
301 | tensorboard-plugin-wit==1.8.1
302 | tensorboardx==2.2
303 | terminado==0.17.1
304 | threadpoolctl==2.2.0
305 | tinycss2==1.2.1
306 | tk==8.6.12
307 | toml==0.10.2
308 | tomli==2.0.1
309 | torchaudio==0.13.1
310 | torchvision==0.14.1
311 | tornado==6.1
312 | tqdm==4.64.1
313 | traitlets==5.8.0
314 | typing-extensions==4.4.0
315 | typing_extensions==4.4.0
316 | uritemplate==4.1.1
317 | urllib3==1.26.13
318 | utf8proc==2.6.1
319 | wandb==0.13.7
320 | wcwidth==0.2.5
321 | webencodings==0.5.1
322 | websocket-client==1.4.2
323 | werkzeug==2.2.2
324 | wheel==0.37.1
325 | xz==5.2.8
326 | yarl==1.8.1
327 | zeromq==4.3.4
328 | zipp==3.11.0
329 | zlib==1.2.13
330 | zstd==1.5.2
331 |
--------------------------------------------------------------------------------