├── .gitignore
├── LICENSE
├── README.md
├── dataloaders.py
├── environment.yml
├── log.py
├── main.py
├── main_hetro.py
├── main_horder.py
├── models
├── layer.py
├── model.py
└── model_horder.py
├── subg_acc
├── .gitignore
├── LICENSE
├── README.md
├── graph_acc.c
├── setup.py
└── uthash.h
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # custom
2 | .DS_Store
3 | .idea/*
4 | __pycache__/*
5 | notebooks/.ipynb_checkpoints/*
6 | log/*
7 | models/__pycache__/*
8 | *pyc
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 | cover/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | .pybuilder/
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | # For a library or package, you might want to ignore these files since the code is
96 | # intended to run in multiple environments; otherwise, check them in:
97 | # .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 | # pytype static type analyzer
144 | .pytype/
145 |
146 | # Cython debug symbols
147 | cython_debug/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2022-Present, Haoteng Yin and GCoM@Purdue
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
SUREL: Subgraph-based Graph Representation Learning Framework
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | SUREL is a novel walk-based computation framework for efficient large-scale subgraph-based graph representation learning (SGRL). Details on how SUREL works can be found in our VLDB'22 paper [Algorithm and System Co-design for Efficient Subgraph-based Graph Representation Learning](https://arxiv.org/pdf/2202.13538.pdf).
11 |
12 | Currently, we support:
13 | - Large-scale graph ML tasks: link prediction / relation-type prediction / higher-order pattern prediction
14 | - Preprocessing and training of datasets in OGB format
15 | - Python API ([SubG Library](https://github.com/VeritasYin/subg_acc)) for subgraph sampling and joining procedures
16 | - Single GPU training and evaluation
17 | - Structural (Relative Position) Encoding + Node Features
18 | - [VesselGraph](https://paperswithcode.com/dataset/vesselgraph) Dataset
19 |
20 | We are working on expanding the functionality of SUREL to include:
21 | - Multi-GPU training
22 |
23 | ## Requirements ##
24 | (Other versions may work, but are untested)
25 | * Ubuntu 20.04
26 | * CUDA >= 10.2
27 | * python >= 3.8
28 | * 1.8 <= pytorch <= 1.12
29 |
30 | ## Datasets
31 |
32 | SGRL datasets (`mag-write (P-A)`, `mag-cite (P-P)`, `tags-math`, `DBLP-coauthor`) for relation and higher-order prediction can be accessed via [Zenodo](https://zenodo.org/records/15186012).
33 |
34 | ## SGRL Environment Setup ##
35 |
36 | Requirements: Python >= 3.8, [Anaconda3](https://www.anaconda.com/)
37 |
38 | - Update conda:
39 | ```bash
40 | conda update -n base -c defaults conda
41 | ```
42 |
43 | - Install basic dependencies to virtual environment and activate it:
44 | ```bash
45 | conda env create -f environment.yml
46 | conda activate sgrl_env
47 | ```
48 | - **SUREL** now supports PyTorch 1.12.1 and PyG 2.2.0. To install them, simply run:
49 | ```bash
50 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
51 | conda install pyg -c pyg
52 | ```
53 | For more details, please refer to the [PyTorch](https://pytorch.org/) and [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html).
54 | The code of this repository was most recently tested with Python 3.10.9 + PyTorch 1.12.1 (CUDA 11.3) + torch-geometric 2.2.0.
55 |
56 | - Example installation commands for PyTorch 1.8.0 (CUDA 10.2) and torch-geometric 1.6.3:
57 | ```bash
58 | conda install pytorch==1.8.0 torchvision torchaudio cudatoolkit=10.2
59 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
60 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
61 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
62 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
63 | pip install torch-geometric==1.6.3
64 | ```
65 |
66 | ## Quick Start
67 |
68 | 1. Install required version of PyTorch that is compatible with your CUDA driver
69 |
70 | 2. Clone the repository `git clone https://github.com/Graph-COM/SUREL`
71 |
72 | 3. Build and install the [SubG](https://github.com/Graph-COM/SUREL/tree/main/subg_acc) library (v1.1) `cd subg_acc;python3 setup.py install`
73 |
74 | - To train **SUREL** for link prediction on Collab:
75 | ```bash
76 | python main.py --dataset ogbl-collab --metric hit --num_step 4 --num_walk 200 --use_val
77 | ```
78 |
79 | - To train **SUREL** for link prediction on Citation2:
80 | ```bash
81 | python main.py --dataset ogbl-citation2 --metric mrr --num_step 4 --num_walk 100
82 | ```
83 |
84 | - To train **SUREL** for relation prediction on MAG(A-P):
85 | ```bash
86 | python main_hetro.py --dataset mag --relation write --metric mrr --num_step 3 --num_walk 100 --k 10
87 | ```
88 |
89 | - To train **SUREL** for higher-order prediction on DBLP:
90 | ```bash
91 | python main_horder.py --dataset DBLP-coauthor --metric mrr --num_step 3 --num_walk 100
92 | ```
93 |
94 | - All detailed training logs can be found at `<log_dir>/<dataset>/<stamp>.log`.
95 |
96 | ## Result Reproduction
97 | This section supplements our SUREL paper accepted at VLDB'22. To reproduce the results of SUREL reported in Tables 3 and 4, use the following commands:
98 | * OGBL - Link Prediction
99 | ```bash
100 | python3 main.py --dataset <dataset> --metric <metric> --num_step <num_step> --num_walk <num_walk> --k <k>
101 | ```
102 | where `dataset` can be one of `ogbl-citation2`, `ogbl-collab`, or `ogbl-ppa`; `metric` can be either `mrr` or `hit`.
103 | * Relation Type Prediction
104 | ```bash
105 | python main_hetro.py --dataset mag --relation <relation> --metric mrr --num_step <num_step> --num_walk <num_walk> --k <k>
106 | ```
107 | where `relation` can be either `write` or `cite`.
108 | * Higher-order Pattern Prediction
109 | ```bash
110 | python main_horder.py --dataset <dataset> --metric mrr --num_step <num_step> --num_walk <num_walk> --k <k>
111 | ```
112 | where `dataset` can be either `DBLP-coauthor` or `tags-math`.
113 |
114 | The detailed parameter configurations are provided in Table 8, Appendix D of the [arxiv version](https://arxiv.org/abs/2202.13538) of this work. For the profiling of SUREL in Table 4 and Fig. 4 (a-b), please use the parameter setting provided in Appendix D.3.
115 |
116 | To test the scaling performance of the Walk Sampler and RPE Joining, the functions `run_walk` and `sjoin` can be imported from the module `surel_gacc`. Adjust the values of `num_walk`, `num_step`, and `nthread` as shown in Fig. 4 (c-d).
117 |
118 | To perform the hyper-parameter analysis of the number of walks 𝑀, the number of walk steps 𝑚, and the hidden dimension 𝑑, adjust the values of `num_walk`, `num_step`, and `hidden_dim` as shown in Fig. 5.
119 |
120 |
121 | Sample Output
122 |
123 | ```text
124 | 2022-03-25 15:57:16,677 - root - INFO - Create log file at ./log/ogbl-citation2/032522_155716.log
125 | 2022-03-25 15:57:16,677 - root - INFO - Command line executed: python main.py --gpu 2 --patience 5 --hidden_dim 64 --seed 0
126 | 2022-03-25 15:57:16,677 - root - INFO - Full args parsed:
127 | 2022-03-25 15:57:16,677 - root - INFO - Namespace(B_size=1500, batch_num=2000, batch_size=32, data_usage=1.0, dataset='ogbl-citation2', debug=False, directed=False, dropout=0.1, eval_steps=100, gpu_id=2, hidden_dim=64, k=50, l2=0.0, layers=2, load_dict=False, load_model=False, log_dir='./log/', lr=0.001, memo=None, metric='mrr', model='RNN', norm='all', nthread=16, num_step=4, num_walk=100, optim='adam', patience=5, repeat=1, res_dir='./dataset/save', rtest=499, save=False, seed=0, stamp='032522_155716', summary_file='result_summary.log', test_ratio=1.0, train_ratio=0.05, use_degree=False, use_feature=False, use_htype=False, use_val=False, use_weight=False, valid_ratio=0.1, x_dim=0)
128 | 2022-03-25 15:57:16,727 - root - INFO - torch num_threads 16
129 | 2022-03-25 15:57:26,536 - root - INFO - eval metric mrr
130 | task type link prediction
131 | download_name citation-v2
132 | version 1
133 | url http://snap.stanford.edu/ogb/data/linkproppred...
134 | add_inverse_edge False
135 | has_node_attr True
136 | has_edge_attr False
137 | split time
138 | additional node files node_year
139 | additional edge files None
140 | is hetero False
141 | binary False
142 | Name: ogbl-citation2, dtype: object
143 | Keys: ['x', 'edge_index', 'node_year']
144 | 2022-03-25 15:57:26,536 - root - INFO - node size 2927963, feature dim 128, edge size 30387995 with mask ratio 0.05
145 | 2022-03-25 15:57:26,536 - root - INFO - use_weight False, use_coalesce False, use_degree False, use_val False
146 | 2022-03-25 15:57:45,775 - root - INFO - Sparsity of loaded graph 6.727197221716796e-06
147 | 2022-03-25 15:57:45,782 - root - INFO - Observed subgraph with 2918932 nodes and 28836021 edges;
148 | 2022-03-25 15:57:45,789 - root - INFO - Training subgraph with 1394162 nodes and 1519315 edges.
149 | 2022-03-25 15:57:50,400 - root - INFO - #Model Params 79617
150 | 2022-03-25 15:59:14,643 - root - INFO - Samples: valid 8659 by 1000 test 86596 by 1000 metric: mrr
151 | 2022-03-25 15:59:15,405 - root - INFO - Running Round 1
152 | 2022-03-25 15:59:29,229 - root - INFO - Batch 1 W1502/D1394162 Loss: 0.1971, AUC: 0.5049
153 | 2022-03-25 15:59:42,266 - root - INFO - Batch 2 W2991/D1394162 Loss: 0.1097, AUC: 0.4975
154 | 2022-03-25 15:59:56,187 - root - INFO - Batch 3 W4431/D1394162 Loss: 0.1024, AUC: 0.4976
155 | 2022-03-25 16:00:09,070 - root - INFO - Batch 4 W5761/D1394162 Loss: 0.1030, AUC: 0.4980
156 | 2022-03-25 16:00:23,285 - root - INFO - Batch 5 W7215/D1394162 Loss: 0.1013, AUC: 0.5053
157 | ...
158 | ```
159 |
160 |
161 | ## Usage
162 | ```
163 | usage: Interface for SUREL framework [-h]
164 | [--dataset {ogbl-ppa,ogbl-citation2,ogbl-collab,mag,DBLP-coauthor,tags-math}]
165 | [--model {RNN,MLP,Transformer,GNN}]
166 | [--layers LAYERS]
167 | [--hidden_dim HIDDEN_DIM] [--x_dim X_DIM]
168 | [--data_usage DATA_USAGE]
169 | [--train_ratio TRAIN_RATIO]
170 | [--valid_ratio VALID_RATIO]
171 | [--test_ratio TEST_RATIO]
172 | [--metric {auc,mrr,hit}] [--seed SEED]
173 | [--gpu_id GPU_ID] [--nthread NTHREAD]
174 | [--B_size B_SIZE] [--num_walk NUM_WALK]
175 | [--num_step NUM_STEP] [--k K]
176 | [--directed DIRECTED] [--use_feature]
177 | [--use_weight] [--use_degree]
178 | [--use_htype] [--use_val] [--norm NORM]
179 | [--optim OPTIM] [--rtest RTEST]
180 | [--eval_steps EVAL_STEPS]
181 | [--batch_size BATCH_SIZE]
182 | [--batch_num BATCH_NUM] [--lr LR]
183 | [--dropout DROPOUT] [--l2 L2]
184 | [--patience PATIENCE] [--repeat REPEAT]
185 | [--log_dir LOG_DIR] [--res_dir RES_DIR]
186 | [--stamp STAMP]
187 | [--summary_file SUMMARY_FILE] [--debug]
188 | [--abs] [--save] [--load_dict]
189 | [--load_model] [--memo MEMO]
190 | ```
191 |
192 |
193 | Optional Arguments
194 |
195 | ```
196 | optional arguments:
197 | -h, --help show this help message and exit
198 | --dataset {mag} dataset name
199 | --relation {write,cite}
200 | relation type
201 | --model {RNN,MLP,Transformer,GNN}
202 | base model to use
203 | --layers LAYERS number of layers
204 | --hidden_dim HIDDEN_DIM
205 | hidden dimension
206 | --x_dim X_DIM dim of raw node features
207 | --data_usage DATA_USAGE
208 | use partial dataset
209 | --train_ratio TRAIN_RATIO
210 | mask partial edges for training
211 | --valid_ratio VALID_RATIO
212 | use partial valid set
213 | --test_ratio TEST_RATIO
214 | use partial test set
215 | --metric {auc,mrr,hit}
216 | metric for evaluating performance
217 | --seed SEED seed to initialize all the random modules
218 | --gpu_id GPU_ID gpu id
219 | --nthread NTHREAD number of thread
220 | --B_size B_SIZE set size of train sampling
221 | --num_walk NUM_WALK total number of random walks
222 | --num_step NUM_STEP total steps of random walk
223 | --k K number of paired negative queries
224 | --directed DIRECTED whether to treat the graph as directed
225 | --use_feature whether to use raw features as input
226 | --use_weight whether to use edge weight as input
227 | --use_degree whether to use node degree as input
228 | --use_htype whether to use node type as input
229 | --use_val whether to use val as input
230 | --norm NORM method of normalization
231 | --optim OPTIM optimizer to use
232 | --rtest RTEST step start to test
233 | --eval_steps EVAL_STEPS
234 | number of steps to test
235 | --batch_size BATCH_SIZE
236 | mini-batch size (train)
237 | --batch_num BATCH_NUM
238 | mini-batch size (test)
239 | --lr LR learning rate
240 | --dropout DROPOUT dropout rate
241 | --l2 L2 l2 regularization (weight decay)
242 | --patience PATIENCE early stopping steps
243 | --repeat REPEAT number of training instances to repeat
244 | --log_dir LOG_DIR log directory
245 | --res_dir RES_DIR resource directory
246 | --stamp STAMP time stamp
247 | --summary_file SUMMARY_FILE
248 | brief summary of training results
249 | --debug whether to use debug mode
250 | --save whether to save RPE to files
251 | --load_dict whether to load RPE from files
252 | --load_model whether to load saved model from files
253 | --memo MEMO notes
254 | ```
255 |
256 |
257 | ## Citation
258 | Please cite our paper if you are interested in our work.
259 | ```
260 | @article{yin2022algorithm,
261 | title={Algorithm and System Co-design for Efficient Subgraph-based Graph Representation Learning},
262 | author={Yin, Haoteng and Zhang, Muhan and Wang, Yanbang and Wang, Jianguo and Li, Pan},
263 | journal={Proceedings of the VLDB Endowment},
264 | volume={15},
265 | number={11},
266 | pages={2788-2796},
267 | year={2022}
268 | }
269 | ```
270 |
--------------------------------------------------------------------------------
/dataloaders.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import time
5 | import numpy as np
6 |
7 | import torch
8 | from ogb.linkproppred import PygLinkPropPredDataset
9 | from scipy.sparse import csr_matrix
10 | from torch_sparse import coalesce
11 | from tqdm import tqdm
12 |
13 | from utils import get_pos_neg_edges, np_sampling
14 |
15 |
class DEDataset():
    """Homogeneous OGB link-prediction dataset prepared for SUREL walk sampling.

    ``__init__`` loads an OGB ``PygLinkPropPredDataset`` and extracts its
    training edges. :meth:`process` then masks a fraction of them as positive
    training queries and builds the symmetric CSR adjacency matrices.

    Args:
        dataset: OGB dataset name passed to ``PygLinkPropPredDataset``.
        mask_ratio: fraction of training edges held out as positive samples.
        use_weight: use edge weights; only honoured when the graph has an
            ``edge_weight`` attribute.
        use_coalesce: merge duplicate edges with ``torch_sparse.coalesce``
            (only applied when weights are used).
        use_degree: compute log-degree node features in :meth:`process`.
        use_val: add validation edges to the observed and test-time graphs.
    """

    def __init__(self, dataset, mask_ratio=0.05, use_weight=False, use_coalesce=False, use_degree=False,
                 use_val=False):
        self.data = PygLinkPropPredDataset(name=dataset)
        self.graph = self.data[0]
        self.split_edge = self.data.get_edge_split()
        self.mask_ratio = mask_ratio
        self.use_degree = use_degree
        self.use_weight = (use_weight and 'edge_weight' in self.graph)
        self.use_coalesce = use_coalesce
        self.use_val = use_val
        self.gtype = 'Homogeneous'

        if 'x' in self.graph:
            self.num_nodes, self.num_feature = self.graph['x'].shape
        else:
            self.num_nodes, self.num_feature = len(torch.unique(self.graph['edge_index'])), None

        if 'source_node' in self.split_edge['train']:
            # split stored as source/target node lists: take the edges from the graph itself
            self.directed = True
            self.train_edge = self.graph['edge_index'].t()
        else:
            self.directed = False
            self.train_edge = self.split_edge['train']['edge']

        # Branch on the validated flag, not the raw argument: a graph without
        # ``edge_weight`` must not try to read ``split_edge['train']['weight']``.
        if self.use_weight:
            self.train_weight = self.split_edge['train']['weight']
            if use_coalesce:
                train_edge_col, self.train_weight = coalesce(self.train_edge.t(), self.train_weight,
                                                             self.num_nodes, self.num_nodes)
                self.train_edge = train_edge_col.t()
            # tensor reduction instead of Python max(), which iterates element by element
            self.train_wmax = self.train_weight.max()
        else:
            self.train_weight = None
            # defined in both branches so callers can read it unconditionally
            self.train_wmax = None
        # must come after coalesce, which may merge duplicate edges
        self.len_train = self.train_edge.shape[0]

    @staticmethod
    def _to_sym_csr(edge, weight):
        """Return ``(adjacency, max_idx)`` for an ``(E, 2)`` edge tensor.

        ``adjacency`` is a square CSR matrix of size ``max_idx + 1`` holding
        ``A + A.T``, where ``A`` is built from ``(weight, (edge[:, 0], edge[:, 1]))``.
        Asserts that the diagonal sums to zero (no self-loops).
        """
        max_idx = torch.max(edge).item()
        net = csr_matrix((weight, (edge[:, 0].numpy(), edge[:, 1].numpy())),
                         shape=(max_idx + 1, max_idx + 1))
        adj = net + net.T
        assert adj.diagonal().sum() == 0
        return adj, max_idx

    def process(self, logger):
        """Mask positive training edges and build the adjacency matrices.

        A random ``mask_ratio`` fraction of the training edges (drawn with
        NumPy's global RNG) becomes the positive set; the rest form the
        observed graph.

        Args:
            logger: object with an ``info(str)`` method.

        Returns:
            dict of symmetric CSR matrices: ``'pos'`` (masked positive edges),
            ``'train'`` (observed edges, plus validation edges if ``use_val``),
            ``'val'`` (all training edges) and ``'test'`` (training plus
            validation edges if ``use_val``, otherwise the same object as
            ``'val'``).
        """
        logger.info(f'{self.data.meta_info}\nKeys: {self.graph.keys}')
        logger.info(
            f'node size {self.num_nodes}, feature dim {self.num_feature}, edge size {self.len_train} with mask ratio {self.mask_ratio}')
        logger.info(
            f'use_weight {self.use_weight}, use_coalesce {self.use_coalesce}, use_degree {self.use_degree}, use_val {self.use_val}')

        self.num_pos = int(self.len_train * self.mask_ratio)
        idx = np.random.permutation(self.len_train)
        pos_idx, obsrv_idx = idx[:self.num_pos], idx[self.num_pos:]
        # pos sample edges masked for training, observed edges for structural features
        self.pos_edge, obsrv_edge = self.train_edge[pos_idx], self.train_edge[obsrv_idx]
        val_edge = self.train_edge
        self.val_nodes = torch.unique(self.train_edge).tolist()

        if self.use_weight:
            pos_e_weight = self.train_weight[pos_idx]
            obsrv_e_weight = self.train_weight[obsrv_idx]
            val_e_weight = self.train_weight
        else:
            pos_e_weight = np.ones(self.num_pos, dtype=int)
            obsrv_e_weight = np.ones(self.len_train - self.num_pos, dtype=int)
            val_e_weight = np.ones(self.len_train, dtype=int)

        if self.use_val:
            # collab allows using valid edges for training
            valid_edge = self.split_edge['valid']['edge']
            obsrv_edge = torch.cat([obsrv_edge, valid_edge])
            full_edge = torch.cat([self.train_edge, valid_edge], dim=0)
            self.test_nodes = torch.unique(full_edge).tolist()
            if self.use_weight:
                valid_weight = self.split_edge['valid']['weight']
                obsrv_e_weight = torch.cat([self.train_weight[obsrv_idx], valid_weight])
                full_e_weight = torch.cat([self.train_weight, valid_weight], dim=0)
                if self.use_coalesce:
                    obsrv_edge_col, obsrv_e_weight = coalesce(obsrv_edge.t(), obsrv_e_weight, self.num_nodes,
                                                              self.num_nodes)
                    obsrv_edge = obsrv_edge_col.t()
                    full_edge_col, full_e_weight = coalesce(full_edge.t(), full_e_weight, self.num_nodes,
                                                            self.num_nodes)
                    full_edge = full_edge_col.t()
                self.full_wmax = full_e_weight.max()
            else:
                obsrv_e_weight = np.ones(obsrv_edge.shape[0], dtype=int)
                full_e_weight = np.ones(full_edge.shape[0], dtype=int)
        else:
            self.test_nodes = self.val_nodes

        # observed graph, positive (training) subgraph, and the val/test graphs
        G_obsrv, max_obsrv_idx = self._to_sym_csr(obsrv_edge, obsrv_e_weight)
        G_pos, _ = self._to_sym_csr(self.pos_edge, pos_e_weight)
        G_val, _ = self._to_sym_csr(val_edge, val_e_weight)
        G_full = self._to_sym_csr(full_edge, full_e_weight)[0] if self.use_val else G_val

        self.degree = (np.expand_dims(np.log(G_full.getnnz(axis=1) + 1), 1).astype(np.float32)
                       if self.use_degree else None)

        # sparsity of graph
        logger.info(f'Sparsity of loaded graph {G_obsrv.getnnz() / (max_obsrv_idx + 1) ** 2}')
        # statistic of graph
        logger.info(
            f'Observed subgraph with {np.sum(G_obsrv.getnnz(axis=1) > 0)} nodes and {int(G_obsrv.nnz / 2)} edges;')
        logger.info(f'Training subgraph with {np.sum(G_pos.getnnz(axis=1) > 0)} nodes and {int(G_pos.nnz / 2)} edges.')

        # Release the raw OGB objects to save memory, but keep the node features:
        # gen_dataset() reads dataset.graph['x'] when features are requested.
        node_x = self.graph['x'] if 'x' in self.graph else None
        self.data = None
        self.graph = None if node_x is None else {'x': node_x}

        return {'pos': G_pos, 'train': G_obsrv, 'val': G_val, 'test': G_full}
141 |
142 |
class DE_Hetro_Dataset():
    """Relation-prediction dataset loaded from ``./dataset/{dataset}_{relation}.pl``.

    The file is read with ``torch.load``. The code reads the keys
    ``'split_edge'``, ``'num_nodes_dict'``, ``'edge_index'`` (a mapping keyed by
    relation triple) and, optionally, ``'x'``. Edges of the relation selected by
    ``rel_key`` are always added to the observed graph.

    Args:
        dataset: file-name prefix of the ``.pl`` file.
        relation: ``'write'`` or ``'cite'``.
        mask_ratio: fraction of training edges held out as positive samples.
    """

    def __init__(self, dataset, relation, mask_ratio=0.05):
        self.data = torch.load(f'./dataset/{dataset}_{relation}.pl')
        self.split_edge = self.data['split_edge']
        self.node_type = list(self.data['num_nodes_dict'])
        self.mask_ratio = mask_ratio
        # the relation other than `relation` supplies the extra observed edges
        rel_key = ('author', 'writes', 'paper') if relation == 'cite' else ('paper', 'cites', 'paper')
        self.obsrv_edge = self.data['edge_index'][rel_key]
        self.gtype = 'Heterogeneous' if relation == 'cite' else 'Homogeneous'

        if 'x' in self.data:
            self.num_nodes, self.num_feature = self.data['x'].shape
        else:
            self.num_nodes, self.num_feature = self.obsrv_edge.unique().size(0), None

        train_split = self.split_edge['train']
        if 'source_node' in train_split:
            # Split stored as source/target node lists. This class has no
            # ``self.graph`` (the previous code referenced one and raised
            # AttributeError), so build the (E, 2) edge tensor from the split.
            self.directed = True
            self.train_edge = torch.stack([train_split['source_node'], train_split['target_node']], dim=1)
        else:
            self.directed = False
            self.train_edge = train_split['edge']

        self.len_train = self.train_edge.shape[0]

    @staticmethod
    def _to_sym_csr(edge, weight):
        """Return ``(adjacency, max_idx)`` for an ``(E, 2)`` edge tensor.

        ``adjacency`` is a square CSR matrix of size ``max_idx + 1`` holding
        ``A + A.T``. Asserts that the diagonal sums to zero (no self-loops).
        """
        max_idx = torch.max(edge).item()
        net = csr_matrix((weight, (edge[:, 0].numpy(), edge[:, 1].numpy())),
                         shape=(max_idx + 1, max_idx + 1))
        adj = net + net.T
        assert adj.diagonal().sum() == 0
        return adj, max_idx

    def process(self, logger):
        """Mask positive training edges and build unweighted adjacency matrices.

        A random ``mask_ratio`` fraction of the training edges (drawn with
        NumPy's global RNG) becomes the positive set. The remaining training
        edges plus ``self.obsrv_edge`` form the observed graph.

        Args:
            logger: object with an ``info(str)`` method.

        Returns:
            dict of symmetric CSR matrices ``'pos'``, ``'train'`` (observed),
            ``'val'`` (all training + relation edges) and ``'test'`` (same
            object as ``'val'``).
        """
        logger.info(
            f'node size {self.num_nodes}, feature dim {self.num_feature}, edge size {self.len_train} with mask ratio {self.mask_ratio}')

        self.num_pos = int(self.len_train * self.mask_ratio)
        idx = np.random.permutation(self.len_train)
        # pos sample edges masked for training, observed edges for structural features
        self.pos_edge = self.train_edge[idx[:self.num_pos]]
        obsrv_edge = torch.cat([self.train_edge[idx[self.num_pos:]], self.obsrv_edge])
        val_edge = torch.cat([self.train_edge, self.obsrv_edge])

        # all graphs are unweighted: every edge has weight 1
        G_obsrv, max_obsrv_idx = self._to_sym_csr(obsrv_edge, np.ones(obsrv_edge.shape[0], dtype=int))
        G_pos, _ = self._to_sym_csr(self.pos_edge, np.ones(self.pos_edge.shape[0], dtype=int))
        G_val, _ = self._to_sym_csr(val_edge, np.ones(val_edge.shape[0], dtype=int))
        G_full = G_val

        # sparsity of graph
        logger.info(f'Sparsity of loaded graph {G_obsrv.getnnz() / (max_obsrv_idx + 1) ** 2}')
        # statistic of graph
        logger.info(
            f'Observed subgraph with {np.sum(G_obsrv.getnnz(axis=1) > 0)} nodes and {int(G_obsrv.nnz / 2)} edges;')
        logger.info(f'Training subgraph with {np.sum(G_pos.getnnz(axis=1) > 0)} nodes and {int(G_pos.nnz / 2)} edges.')

        # Keep node features reachable as dataset.graph['x'] (what gen_dataset()
        # reads) before releasing the loaded dict.
        self.graph = {'x': self.data['x']} if 'x' in self.data else None
        self.data = None
        return {'pos': G_pos, 'train': G_obsrv, 'val': G_val, 'test': G_full}
214 |
215 |
class DE_Hyper_Dataset():
    """Higher-order (hyperedge) prediction dataset loaded from ``./dataset/{dataset}.pl``.

    The loaded dict is read for ``'edge_index'`` (converted with
    ``torch.from_numpy``, so a NumPy array of pairwise edges), ``'triplets'``
    (the hyperedges to predict) and, optionally, ``'x'``.

    Args:
        dataset: file-name stem of the ``.pl`` file.
        mask_ratio: stored and logged by :meth:`process`.
    """

    def __init__(self, dataset, mask_ratio=0.6):
        self.data = torch.load(f'./dataset/{dataset}.pl')
        self.obsrv_edge = torch.from_numpy(self.data['edge_index'])
        self.num_tup = len(self.data['triplets'])
        self.mask_ratio = mask_ratio
        self.split_edge = self.data['triplets']
        self.gtype = 'Hypergraph'

        if 'x' in self.data:
            self.num_nodes, self.num_feature = self.data['x'].shape
        else:
            self.num_nodes, self.num_feature = self.obsrv_edge.unique().size(0), None

    def get_edge_split(self, ratio, k=1000, seed=2021):
        """Randomly split the triplets into train / valid / test sets.

        A ``ratio`` fraction goes to training. The remainder is halved into
        validation and test; validation gets the extra triplet when the
        remainder is odd. For every valid/test triplet ``(a, b, c)``, ``k``
        negatives ``(a, b, c')`` are built, with ``c'`` drawn uniformly from
        ``[0, max node id in triplets]``. Test reuses the validation
        negatives, truncated to its length.

        Note: this reseeds NumPy's global RNG with ``seed``. The negatives
        come from torch's global RNG, which is not seeded here.

        Returns:
            ``{'train': {'hedge'}, 'valid': {'hedge', 'hedge_neg'},
            'test': {'hedge', 'hedge_neg'}}``. ``hedge_neg`` holds ``k``
            consecutive rows per positive triplet.
        """
        np.random.seed(seed)
        tuples = torch.from_numpy(self.data['triplets'])
        idx = np.random.permutation(self.num_tup)
        num_train = int(ratio * self.num_tup)
        split_idx = {'train': {'hedge': tuples[idx[:num_train]]}}
        # np.split raised ValueError on an odd remainder; array_split gives the extra one to valid
        val_idx, test_idx = np.array_split(idx[num_train:], 2)
        split_idx['valid'], split_idx['test'] = {'hedge': tuples[val_idx]}, {'hedge': tuples[test_idx]}
        # randint's upper bound is exclusive: +1 so the largest node id can be sampled too
        node_neg = torch.randint(int(torch.max(tuples)) + 1, (len(val_idx), k))
        for split, neg in (('valid', node_neg), ('test', node_neg[:len(test_idx)])):
            # repeat each (a, b) pair k times and pair it with the k sampled third nodes
            pairs = split_idx[split]['hedge'][:, :2].repeat(1, k).view(-1, 2).t()
            split_idx[split]['hedge_neg'] = torch.cat([pairs, neg.reshape(1, -1)]).t()
        return split_idx

    def process(self, logger):
        """Build the symmetric, unweighted CSR adjacency of the pairwise edges.

        Args:
            logger: object with an ``info(str)`` method.

        Returns:
            CSR matrix of shape ``(max_id + 1, max_id + 1)`` holding ``A + A.T``.
        """
        logger.info(
            f'node size {self.num_nodes}, feature dim {self.num_feature}, edge size {self.num_tup} with mask ratio {self.mask_ratio}')
        obsrv_edge = self.obsrv_edge

        # load observed graph and save as a CSR sparse matrix
        max_obsrv_idx = torch.max(obsrv_edge).item()
        obsrv_e_weight = np.ones(len(obsrv_edge), dtype=int)
        net_obsrv = csr_matrix((obsrv_e_weight, (obsrv_edge[:, 0].numpy(), obsrv_edge[:, 1].numpy())),
                               shape=(max_obsrv_idx + 1, max_obsrv_idx + 1))
        G_enc = net_obsrv + net_obsrv.T
        # self-loops are not expected in the input edges
        assert G_enc.diagonal().sum() == 0

        # sparsity of graph
        logger.info(f'Sparsity of loaded graph {G_enc.getnnz() / (max_obsrv_idx + 1) ** 2}')
        # statistic of graph
        logger.info(f'Observed subgraph with {np.sum(G_enc.getnnz(axis=1) > 0)} nodes and {int(G_enc.nnz / 2)} edges;')

        return G_enc
264 |
265 |
def gen_dataset(dataset, graphs, args, bsize=10000):
    """Build the validation/test query sets and pre-sample their random walks.

    Args:
        dataset: processed dataset object. Always reads ``split_edge``. Depending
            on ``args`` it also reads ``degree``, ``node_map``, ``graph['x']``
            or ``train_wmax``.
        graphs: dict returned by ``process``; its ``'val'`` and ``'test'``
            CSR matrices are walked.
        args: parsed CLI namespace. ``args.x_dim`` (when features are used) and
            ``args.w_max`` are written here.
        bsize: batch size forwarded to ``np_sampling``.

    Returns:
        ``(test_dict, val_dict, inf_set)``. The two walk dictionaries come from
        ``np_sampling`` and are the same object when ``args.use_val`` is False.
        ``inf_set`` holds per-split query edges ``'E'``, ``'num_pos'``,
        ``'num_neg'`` and the node inputs ``'X'``.

    Raises:
        NotImplementedError: for metrics other than MRR and Hits.
        ValueError: if node features are requested but the dataset exposes none.
    """
    G_val, G_full = graphs['val'], graphs['test']

    # keep_neg is only enabled for ppa datasets
    keep_neg = 'ppa' in args.dataset

    test_pos_edge, test_neg_edge = get_pos_neg_edges('test', dataset.split_edge, ratio=args.test_ratio,
                                                     keep_neg=keep_neg)
    val_pos_edge, val_neg_edge = get_pos_neg_edges('valid', dataset.split_edge, ratio=args.valid_ratio,
                                                   keep_neg=keep_neg)

    inf_set = {'test': {}, 'val': {}}
    metric = args.metric.lower()

    if metric == 'mrr':
        # positives first; num_neg is the number of negatives per positive
        inf_set['test']['E'] = torch.cat([test_pos_edge, test_neg_edge], dim=1).t()
        inf_set['val']['E'] = torch.cat([val_pos_edge, val_neg_edge], dim=1).t()
        inf_set['test']['num_pos'], inf_set['val']['num_pos'] = test_pos_edge.shape[1], val_pos_edge.shape[1]
        inf_set['test']['num_neg'], inf_set['val']['num_neg'] = test_neg_edge.shape[1] // inf_set['test']['num_pos'], \
            val_neg_edge.shape[1] // inf_set['val']['num_pos']
    elif 'hit' in metric:
        # Case-insensitive: accepts the CLI choice 'hit' as well as 'Hits@K' names
        # (the previous 'Hit' check rejected the lowercase CLI value).
        inf_set['test']['E'] = torch.cat([test_neg_edge, test_pos_edge], dim=1).t()
        inf_set['val']['E'] = torch.cat([val_neg_edge, val_pos_edge], dim=1).t()
        inf_set['test']['num_pos'], inf_set['val']['num_pos'] = test_pos_edge.shape[1], val_pos_edge.shape[1]
        inf_set['test']['num_neg'], inf_set['val']['num_neg'] = test_neg_edge.shape[1], val_neg_edge.shape[1]
    else:
        raise NotImplementedError

    if args.use_val:
        val_dict = np_sampling({}, G_val.indptr, G_val.indices, bsize=bsize,
                               target=torch.unique(inf_set['val']['E']).tolist(), num_walks=args.num_walk,
                               num_steps=args.num_step - 1)
        test_dict = np_sampling({}, G_full.indptr, G_full.indices, bsize=bsize,
                                target=torch.unique(inf_set['test']['E']).tolist(), num_walks=args.num_walk,
                                num_steps=args.num_step - 1)
    else:
        # val and test share one graph, so one walk dictionary serves both
        val_dict = test_dict = np_sampling({}, G_val.indptr, G_val.indices, bsize=bsize,
                                           target=torch.unique(
                                               torch.cat([inf_set['val']['E'], inf_set['test']['E']])).tolist(),
                                           num_walks=args.num_walk, num_steps=args.num_step - 1)

    if not args.use_feature:
        if args.use_degree:
            inf_set['X'] = torch.from_numpy(dataset.degree)
        elif args.use_htype:
            inf_set['X'] = dataset.node_map
        else:
            inf_set['X'] = None
    else:
        graph = getattr(dataset, 'graph', None)
        if graph is None or 'x' not in graph:
            raise ValueError('use_feature is set but the dataset exposes no node features (dataset.graph["x"])')
        inf_set['X'] = graph['x']
        args.x_dim = inf_set['X'].shape[-1]

    # train_wmax only exists on datasets that actually loaded edge weights
    args.w_max = getattr(dataset, 'train_wmax', None) if args.use_weight else None

    return test_dict, val_dict, inf_set
319 |
320 |
def gen_dataset_hyper(dataset, G_enc, args, bsize=10000):
    """Prepare validation/test inference data for higher-order (hyperedge) prediction.

    Args:
        dataset: dataset object exposing ``split_edge`` and ``graph['x']``.
        G_enc: graph whose ``indptr``/``indices`` arrays are handed to ``np_sampling``
            (presumably a scipy CSR matrix -- confirm with the caller).
        args: argparse namespace; reads ``test_ratio``, ``valid_ratio``, ``metric``,
            ``num_walk``, ``num_step`` and ``use_feature``. Sets ``args.x_dim`` when raw
            node features are used.
        bsize: batch size forwarded to ``np_sampling``.

    Returns:
        ``(inf_dict, inf_set)``. ``inf_dict`` is the result of ``np_sampling`` for every
        node appearing in the validation or test candidates. ``inf_set`` holds, per
        split (``'test'``/``'val'``), the candidates ``'E'`` (positives stacked before
        negatives), ``'num_pos'`` and ``'num_neg'`` (negatives per positive, integer
        division), plus the node features ``'X'`` (``None`` without ``--use_feature``).

    Raises:
        NotImplementedError: if ``args.metric`` is not ``'mrr'``.
    """
    test_pos, test_neg = get_pos_neg_edges('test', dataset.split_edge, ratio=args.test_ratio)
    val_pos, val_neg = get_pos_neg_edges('valid', dataset.split_edge, ratio=args.valid_ratio)

    inf_set = {'test': {}, 'val': {}}
    if args.metric != 'mrr':
        raise NotImplementedError

    for split, (pos, neg) in (('test', (test_pos, test_neg)), ('val', (val_pos, val_neg))):
        n_pos = pos.shape[0]
        inf_set[split]['E'] = torch.cat([pos, neg])
        inf_set[split]['num_pos'] = n_pos
        # number of negatives per positive query
        inf_set[split]['num_neg'] = neg.shape[0] // n_pos

    # walks are sampled once for the union of all nodes used by validation and test
    node_ids = torch.unique(torch.cat([inf_set['val']['E'], inf_set['test']['E']])).tolist()
    inf_dict = np_sampling({}, G_enc.indptr, G_enc.indices,
                           bsize=bsize,
                           target=node_ids,
                           num_walks=args.num_walk,
                           num_steps=args.num_step - 1)

    if args.use_feature:
        inf_set['X'] = dataset.graph['x']
        args.x_dim = inf_set['X'].shape[-1]
    else:
        inf_set['X'] = None

    return inf_dict, inf_set
349 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: sgrl_env
2 | channels:
3 | - pyg
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - absl-py=1.3.0=py310h06a4308_0
10 | - aiohttp=3.8.3=py310h5eee18b_0
11 | - aiosignal=1.2.0=pyhd3eb1b0_0
12 | - async-timeout=4.0.2=py310h06a4308_0
13 | - attrs=22.1.0=py310h06a4308_0
14 | - blas=1.0=mkl
15 | - blinker=1.4=py310h06a4308_0
16 | - brotlipy=0.7.0=py310h7f8727e_1002
17 | - bzip2=1.0.8=h7b6447c_0
18 | - c-ares=1.18.1=h7f8727e_0
19 | - ca-certificates=2023.01.10=h06a4308_0
20 | - cachetools=4.2.2=pyhd3eb1b0_0
21 | - certifi=2022.12.7=py310h06a4308_0
22 | - cffi=1.15.1=py310h5eee18b_3
23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
24 | - click=8.0.4=py310h06a4308_0
25 | - cryptography=38.0.4=py310h9ce1e76_0
26 | - cudatoolkit=11.3.1=h2bc3f7f_2
27 | - ffmpeg=4.3=hf484d3e_0
28 | - fftw=3.3.9=h27cfd23_1
29 | - flit-core=3.6.0=pyhd3eb1b0_0
30 | - freetype=2.12.1=h4a9f257_0
31 | - frozenlist=1.3.3=py310h5eee18b_0
32 | - giflib=5.2.1=h7b6447c_0
33 | - gmp=6.2.1=h295c915_3
34 | - gnutls=3.6.15=he1e5248_0
35 | - google-auth=2.6.0=pyhd3eb1b0_0
36 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
37 | - grpcio=1.42.0=py310hce63b2e_0
38 | - idna=3.4=py310h06a4308_0
39 | - intel-openmp=2021.4.0=h06a4308_3561
40 | - jinja2=3.1.2=py310h06a4308_0
41 | - joblib=1.1.1=py310h06a4308_0
42 | - jpeg=9e=h7f8727e_0
43 | - lame=3.100=h7b6447c_0
44 | - lcms2=2.12=h3be6417_0
45 | - ld_impl_linux-64=2.38=h1181459_1
46 | - lerc=3.0=h295c915_0
47 | - libdeflate=1.8=h7f8727e_5
48 | - libffi=3.4.2=h6a678d5_6
49 | - libgcc-ng=11.2.0=h1234567_1
50 | - libgfortran-ng=11.2.0=h00389a5_1
51 | - libgfortran5=11.2.0=h1234567_1
52 | - libgomp=11.2.0=h1234567_1
53 | - libiconv=1.16=h7f8727e_2
54 | - libidn2=2.3.2=h7f8727e_0
55 | - libpng=1.6.37=hbc83047_0
56 | - libprotobuf=3.20.1=h4ff587b_0
57 | - libstdcxx-ng=11.2.0=h1234567_1
58 | - libtasn1=4.16.0=h27cfd23_0
59 | - libtiff=4.5.0=hecacb30_0
60 | - libunistring=0.9.10=h27cfd23_0
61 | - libuuid=1.41.5=h5eee18b_0
62 | - libwebp=1.2.4=h11a3e52_0
63 | - libwebp-base=1.2.4=h5eee18b_0
64 | - lz4-c=1.9.4=h6a678d5_0
65 | - markdown=3.4.1=py310h06a4308_0
66 | - markupsafe=2.1.1=py310h7f8727e_0
67 | - mkl=2021.4.0=h06a4308_640
68 | - mkl-service=2.4.0=py310h7f8727e_0
69 | - mkl_fft=1.3.1=py310hd6ae3a3_0
70 | - mkl_random=1.2.2=py310h00e6091_0
71 | - multidict=6.0.2=py310h5eee18b_0
72 | - ncurses=6.3=h5eee18b_3
73 | - nettle=3.7.3=hbbd107a_1
74 | - numpy=1.23.5=py310hd5efca6_0
75 | - numpy-base=1.23.5=py310h8e6c178_0
76 | - oauthlib=3.2.1=py310h06a4308_0
77 | - openh264=2.1.1=h4ff587b_0
78 | - openssl=1.1.1s=h7f8727e_0
79 | - pillow=9.3.0=py310hace64e9_1
80 | - pip=22.3.1=py310h06a4308_0
81 | - protobuf=3.20.1=py310h295c915_0
82 | - psutil=5.9.0=py310h5eee18b_0
83 | - pyasn1=0.4.8=pyhd3eb1b0_0
84 | - pyasn1-modules=0.2.8=py_0
85 | - pycparser=2.21=pyhd3eb1b0_0
86 | - pyg=2.2.0=py310_torch_1.12.0_cu113
87 | - pyjwt=2.4.0=py310h06a4308_0
88 | - pyopenssl=22.0.0=pyhd3eb1b0_0
89 | - pyparsing=3.0.9=py310h06a4308_0
90 | - pysocks=1.7.1=py310h06a4308_0
91 | - python=3.10.9=h7a1cb2a_0
92 | - pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0
93 | - pytorch-cluster=1.6.0=py310_torch_1.12.0_cu113
94 | - pytorch-mutex=1.0=cuda
95 | - pytorch-scatter=2.1.0=py310_torch_1.12.0_cu113
96 | - pytorch-sparse=0.6.16=py310_torch_1.12.0_cu113
97 | - readline=8.2=h5eee18b_0
98 | - requests=2.28.1=py310h06a4308_0
99 | - requests-oauthlib=1.3.0=py_0
100 | - rsa=4.7.2=pyhd3eb1b0_1
101 | - scikit-learn=1.2.0=py310h6a678d5_0
102 | - scipy=1.9.3=py310hd5efca6_0
103 | - setuptools=65.6.3=py310h06a4308_0
104 | - six=1.16.0=pyhd3eb1b0_1
105 | - sqlite=3.40.1=h5082296_0
106 | - tensorboard=2.10.0=py310h06a4308_0
107 | - tensorboard-data-server=0.6.1=py310h52d8a92_0
108 | - tensorboard-plugin-wit=1.8.1=py310h06a4308_0
109 | - threadpoolctl=2.2.0=pyh0d69192_0
110 | - tk=8.6.12=h1ccaba5_0
111 | - torchaudio=0.12.1=py310_cu113
112 | - torchvision=0.13.1=py310_cu113
113 | - tqdm=4.64.1=py310h06a4308_0
114 | - typing_extensions=4.4.0=py310h06a4308_0
115 | - tzdata=2022g=h04d1e81_0
116 | - urllib3=1.26.14=py310h06a4308_0
117 | - werkzeug=2.2.2=py310h06a4308_0
118 | - wheel=0.37.1=pyhd3eb1b0_0
119 | - xz=5.2.10=h5eee18b_1
120 | - yarl=1.8.1=py310h5eee18b_0
121 | - zlib=1.2.13=h5eee18b_0
122 | - zstd=1.5.2=ha4553b6_0
123 | - pip:
124 | - fastremap==1.13.3
125 | - littleutils==0.2.2
126 | - llvmlite==0.39.1
127 | - numba==0.56.4
128 | - ogb==1.3.5
129 | - outdated==0.2.2
130 | - pandas==1.5.3
131 | - pyg-lib==0.1.0+pt112cu113
132 | - python-dateutil==2.8.2
133 | - pytz==2022.7.1
134 | - streamtologger==2017.1
135 |
--------------------------------------------------------------------------------
/log.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import logging
5 | import os
6 | import socket
7 | import time
8 |
9 | import numpy as np
10 | import streamtologger
11 |
12 |
def set_up_log(args, sys_argv):
    """Create output directories and attach file/console handlers to the root logger.

    Creates ``<res_dir>/model/<dataset>`` and ``<log_dir>/<dataset>`` (including any
    missing parent directories). Adds a DEBUG-level ``FileHandler`` writing to
    ``<log_dir>/<dataset>/<stamp>.log`` and a WARN-level console ``StreamHandler`` to
    the root logger, then logs the command line and the parsed arguments.

    Args:
        args: parsed argparse namespace. Reads ``log_dir``, ``res_dir``, ``dataset`` and
            ``debug``. ``args.stamp`` is kept when it is already non-empty (e.g. given
            via ``--stamp``); otherwise it is set to the current time as
            ``%m%d%y_%H%M%S``.
        sys_argv: raw command-line argument list, recorded in the log.

    Returns:
        The root logger.
    """
    log_dir = args.log_dir
    save_dir = os.path.join(args.res_dir, 'model', args.dataset)
    dataset_log_dir = os.path.join(log_dir, args.dataset)
    # makedirs also creates missing parents (e.g. a nested --log_dir) and does not
    # fail when the directory already exists.
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(dataset_log_dir, exist_ok=True)

    # Honour an explicit stamp so files named after an earlier run's stamp can be
    # found again; otherwise derive a fresh one from the current time.
    if not getattr(args, 'stamp', ''):
        args.stamp = time.strftime('%m%d%y_%H%M%S')
    file_path = os.path.join(dataset_log_dir, f"{args.stamp}.log")

    # basicConfig only installs its console handler if the root logger has none yet.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(file_path)
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.WARN)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(ch)
    logger.info('Create log file at {}'.format(file_path))
    logger.info('Command line executed: python ' + ' '.join(sys_argv))
    logger.info('Full args parsed:')
    logger.info(args)
    if args.debug:
        # mirror stdout/stderr into the logger while debugging
        streamtologger.redirect(target=logger)
    return logger
46 |
47 |
def save_performance_result(args, logger, metrics, repeat=0):
    """Append a one-line, tab-separated result summary to ``<log_dir>/<summary_file>``.

    Nothing is written when ``args.summary_file`` is the sentinel value ``'test'``.

    Args:
        args: argparse namespace; reads ``summary_file``, ``log_dir``, ``dataset``,
            ``model``, ``num_step``, ``num_walk``, ``seed`` and ``K`` (falling back to
            the ``k`` attribute that the command-line scripts define).
        logger: logger carrying a ``logging.FileHandler``; the base name of its log
            file is recorded in the summary (``'N/A'`` if there is none).
        metrics: pair ``(val_metric, no_val_metric)`` of numbers, rounded to 4 digits.
        repeat: index of the training repeat, recorded as ``R<repeat>``.

    Raises:
        Warning: if the summary file cannot be opened or written.
    """
    if args.summary_file == 'test':
        return
    summary_file = os.path.join(args.log_dir, args.summary_file)

    val_metric, no_val_metric = metrics
    # The CLI scripts define ``--k`` (lowercase); accept either spelling.
    k = args.K if hasattr(args, 'K') else args.k
    model_name = '-'.join([args.model, str(args.num_step), str(args.num_walk), str(k)])

    # Locate the log file by handler type rather than by position in logger.handlers.
    file_handlers = [h for h in logger.handlers if isinstance(h, logging.FileHandler)]
    log_name = os.path.basename(file_handlers[-1].baseFilename) if file_handlers else 'N/A'

    line = '\t'.join([
        args.dataset, model_name, str(args.seed), str(round(val_metric, 4)), f'R{repeat}',
        str(round(no_val_metric, 4)), log_name, socket.gethostname(),
    ]) + '\n'
    try:
        with open(summary_file, 'a') as f:
            f.write(line)
    except OSError as err:
        raise Warning(f'Unable to write back summary file at {summary_file}.') from err
68 |
69 |
def save_to_file(dic, args, logger, dtype):
    """Optionally dump a walk dictionary to disk and return a copy marked as converted.

    When ``args.save`` is set, the entries of ``dic`` (without its ``'num'`` entry) are
    written with ``np.savez`` to
    ``<res_dir>/dict/<dataset>_<dtype>_<num_step>_<num_walk>_<W|R>_<uval|wo>.pt``,
    unless that dump already exists. ``np.savez`` appends ``.npz`` to names that do
    not already end with it, so the file on disk is ``<...>.pt.npz``; both spellings
    are checked.

    Args:
        dic: mapping whose non-``'num'`` values unpack into ``(walks, ids, freqs)``
            triples (presumably node -> random-walk data; confirm with the caller).
        args: argparse namespace; reads ``save``, ``use_val``, ``use_weight``,
            ``res_dir``, ``dataset``, ``num_step`` and ``num_walk``.
        logger: logger used for progress messages.
        dtype: split name used in the file name (``'test'`` selects the ``uval`` suffix
            when ``args.use_val`` is set).

    Returns:
        A shallow copy of ``dic`` with ``'flag'`` set to ``False``. When a new dump was
        written, the ``'num'`` entry is absent from the copy.
    """
    save_dict = dic.copy()
    flag = 'W' if args.use_weight else 'R'

    if not args.save:
        logger.info(f'Converted {dtype} set to tensor.')
    else:
        suffix = 'uval' if args.use_val and dtype == 'test' else 'wo'
        file_name = f'{args.res_dir}/dict/{args.dataset}_{dtype}_{args.num_step}_{args.num_walk}_{flag}_{suffix}.pt'
        # np.savez actually writes '<file_name>.npz'; without this second check an
        # existing dump would never be detected and would be rewritten every time.
        if os.path.exists(file_name) or os.path.exists(file_name + '.npz'):
            logger.info(f'File exists, {dtype} skipped.')
        else:
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            save_dict.pop('num', None)
            keys, values = list(save_dict.keys()), list(save_dict.values())
            walks, ids, freqs = zip(*values)
            np.savez(file_name, X=keys, Y=ids, W=walks, F=freqs)
            logger.info(f'Saved {dtype} set to {file_name}')

    save_dict['flag'] = False
    return save_dict
91 |
92 |
def log_record(logger, tb, out, dic, b_time, batchIdx):
    """Record one evaluation result and report whether it is the best so far.

    Appends the evaluation metric and AUC from ``out`` to the history lists in ``dic``,
    logs them (and writes them to TensorBoard when ``tb`` is given).

    Args:
        logger: logger for progress messages.
        tb: TensorBoard ``SummaryWriter`` or ``None``.
        out: evaluation result with keys ``'mode'`` (``'val'``/``'test'``), ``'metric'``,
            ``'auc'`` and either ``'mrr_list'`` (for ``'mrr'``) or ``'hits'``/``'num_pos'``
            (for ``'Hits@K'`` metrics).
        dic: history dict with lists under ``'<mode>_<metric>'`` and ``'<mode>_AUC'``.
        b_time: ``time.time()`` taken when the evaluation started.
        batchIdx: training step used as the TensorBoard x-axis.

    Returns:
        For ``'val'``: ``True`` if the newest validation score is the (first) maximum.
        For ``'test'``: ``True`` if the newest test entry is the one paired with the best
        validation score. Otherwise ``False``.

    Raises:
        NotImplementedError: for metrics other than ``'mrr'`` or ``'Hit...'``.
    """
    mode, metric, auc = out['mode'], out['metric'], out['auc']
    dt = time.time() - b_time
    if tb is not None:
        tb.add_scalar(f"AUC/{mode}", auc, batchIdx)
    key_metric, key_auc = f'{mode}_{metric}', f'{mode}_AUC'

    if metric == 'mrr':
        out_metric = out['mrr_list'].mean()
        if tb is not None:
            tb.add_scalar(f"MRR/{mode}", out_metric, batchIdx)
        logger.info(f"AUC/{mode}: {auc:.4f}, MRR {out_metric:.4f} # {len(out['mrr_list'])} Time {dt:.2f}s")
        dic[key_metric].append(out_metric.item())
    elif 'Hit' in metric:
        if tb is not None:
            tb.add_scalars(f"Hits/{mode}", out['hits'], batchIdx)
        hits = ' '.join([f'{k}: {v:.4f}' for k, v in out['hits'].items()])
        logger.info(f"AUC/{mode}: {auc:.4f}, {hits} # {out['num_pos']} Time {dt:.2f}s")
        dic[key_metric].append(out['hits'][metric])
    else:
        raise NotImplementedError

    dic[key_auc].append(auc)
    val_metric = f'val_{metric}'
    len_val = len(dic[val_metric])
    if mode == 'test':
        len_test = len(dic[key_metric])
        # ``best`` is a position in the test history. When validation ran more often
        # than test, test entries are aligned with the *last* len_test val entries.
        if len_val > len_test:
            best = int(np.argmax(dic[val_metric][-len_test:]))
            val_best = dic[val_metric][len_val - len_test + best]
        else:
            best = int(np.argmax(dic[val_metric]))
            val_best = dic[val_metric][best]
        logger.info(f'Best {metric}: val {val_best:.4f} test {dic[key_metric][best]:.4f}')
        # compare non-negative positions so the newest test entry can be detected
        return best == len_test - 1
    if mode == 'val':
        return int(np.argmax(dic[val_metric])) == len_val - 1
    return False
131 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 |
6 | import torch
7 | from surel_gacc import run_sample
8 | from ogb.linkproppred import Evaluator
9 | from torch.utils.tensorboard import SummaryWriter
10 | from torch_geometric.utils import subgraph
11 |
12 | from dataloaders import *
13 | from log import *
14 | from models.model import Net
15 | from train import *
16 | from utils import *
17 | import sys
18 |
if __name__ == '__main__':
    # Entry point: train SUREL for link prediction on OGB datasets, evaluating on
    # validation/test periodically with patience-based early stopping.
    parser = argparse.ArgumentParser('Interface for SUREL framework')

    def str2bool(value):
        """Parse a boolean command-line value such as 'True'/'False', 'yes'/'no' or '1'/'0'."""
        # argparse's ``type=bool`` treats every non-empty string (including 'False') as True.
        return str(value).strip().lower() in ('true', 't', 'yes', 'y', '1')

    # general model and training setting
    parser.add_argument('--dataset', type=str, default='ogbl-citation2', help='dataset name',
                        choices=['ogbl-ppa', 'ogbl-ddi', 'ogbl-citation2', 'ogbl-collab', 'mag'])
    parser.add_argument('--model', type=str, default='RNN', help='base model to use',
                        choices=['RNN', 'MLP', 'Transformer', 'GNN'])
    parser.add_argument('--layers', type=int, default=2, help='number of layers')
    parser.add_argument('--hidden_dim', type=int, default=64, help='hidden dimension')
    parser.add_argument('--x_dim', type=int, default=0, help='dim of raw node features')
    parser.add_argument('--data_usage', type=float, default=1.0, help='use partial dataset')
    parser.add_argument('--train_ratio', type=float, default=0.05, help='mask partial edges for training')
    parser.add_argument('--valid_ratio', type=float, default=0.1, help='use partial valid set')
    parser.add_argument('--test_ratio', type=float, default=1.0, help='use partial test set')
    parser.add_argument('--metric', type=str, default='mrr', help='metric for evaluating performance',
                        choices=['auc', 'mrr', 'hit'])
    parser.add_argument('--seed', type=int, default=0, help='seed to initialize all the random modules')
    parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
    parser.add_argument('--nthread', type=int, default=16, help='number of thread')

    # features and positional encoding
    parser.add_argument('--B_size', type=int, default=1500, help='set size of train sampling')
    parser.add_argument('--num_walk', type=int, default=100, help='total number of random walks')
    parser.add_argument('--num_step', type=int, default=4, help='total steps of random walk')
    parser.add_argument('--k', type=int, default=50, help='number of paired negative edges')
    parser.add_argument('--directed', type=str2bool, default=False, help='whether to treat the graph as directed')
    parser.add_argument('--use_feature', action='store_true', help='whether to use raw features as input')
    parser.add_argument('--use_weight', action='store_true', help='whether to use edge weight as input')
    parser.add_argument('--use_degree', action='store_true', help='whether to use node degree as input')
    parser.add_argument('--use_htype', action='store_true', help='whether to use node type as input')
    parser.add_argument('--use_val', action='store_true', help='whether to use val as input')
    parser.add_argument('--norm', type=str, default='all', help='method of normalization')

    # model training
    parser.add_argument('--optim', type=str, default='adam', help='optimizer to use')
    parser.add_argument('--rtest', type=int, default=499, help='step start to test')
    parser.add_argument('--eval_steps', type=int, default=100, help='number of steps to test')
    parser.add_argument('--batch_size', type=int, default=32, help='mini-batch size (train)')
    parser.add_argument('--batch_num', type=int, default=2000, help='mini-batch size (test)')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')
    parser.add_argument('--l2', type=float, default=0., help='l2 regularization (weight decay)')
    parser.add_argument('--patience', type=int, default=5, help='early stopping steps')
    parser.add_argument('--repeat', type=int, default=1, help='number of training instances to repeat')

    # logging & debug
    parser.add_argument('--log_dir', type=str, default='./log/', help='log directory')
    parser.add_argument('--res_dir', type=str, default='./dataset/save', help='resource directory')
    parser.add_argument('--stamp', type=str, default='', help='time stamp')
    parser.add_argument('--summary_file', type=str, default='result_summary.log',
                        help='brief summary of training results')
    parser.add_argument('--debug', default=False, action='store_true', help='whether to use debug mode')
    parser.add_argument('--load_dict', default=False, action='store_true', help='whether to load RPE from files')
    parser.add_argument('--save', default=False, action='store_true', help='whether to save RPE to files')
    parser.add_argument('--load_model', default=False, action='store_true',
                        help='whether to load saved model from files')
    parser.add_argument('--memo', type=str, help='notes')

    sys_argv = sys.argv
    try:
        args = parser.parse_args()
    except SystemExit as err:
        # argparse exits with status 0 after --help and 2 on invalid arguments: show the
        # full help for errors, but propagate the real exit status instead of masking it.
        if err.code:
            parser.print_help()
        raise

    set_random_seed(args)

    # customized for each dataset
    if 'ddi' in args.dataset:
        args.metric = 'Hits@20'
    elif 'collab' in args.dataset:
        args.metric = 'Hits@50'
    elif 'ppa' in args.dataset:
        args.metric = 'Hits@100'
    elif 'citation' in args.dataset:
        args.metric = 'mrr'
    else:
        raise NotImplementedError

    # setup logger and tensorboard
    logger = set_up_log(args, sys_argv)
    if args.nthread > 0:
        torch.set_num_threads(args.nthread)
    logger.info(f"torch num_threads {torch.get_num_threads()}")
    tb = SummaryWriter()

    save_dir = f'{args.res_dir}/model/{args.dataset}'
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu')
    # checkpoint path prefix; args.stamp is filled in by set_up_log
    prefix = f'{save_dir}/{args.stamp}_{args.num_step}_{args.num_walk}'

    g_class = DEDataset(args.dataset, args.train_ratio,
                        use_weight=args.use_weight,
                        use_coalesce=args.use_weight,
                        use_degree=args.use_degree,
                        use_val=args.use_val)
    evaluator = Evaluator(name=args.dataset)
    graphs = g_class.process(logger)

    # edges for negative sampling
    T_edge_idx, F_edge_idx = g_class.pos_edge.t().contiguous(), g_class.train_edge.t().contiguous()

    # define model and optim
    model = Net(num_layers=args.layers, input_dim=args.num_step, hidden_dim=args.hidden_dim, out_dim=1,
                num_walk=args.num_walk, x_dim=args.x_dim, dropout=args.dropout, use_feature=args.use_feature,
                use_weight=args.use_weight, use_degree=args.use_degree, use_htype=args.use_htype)
    model.to(device)

    logger.info(f'#Model Params {sum(p.numel() for p in model.parameters())}')

    def build_optimizer(net):
        """Return a fresh optimizer (selected by --optim) over the parameters of ``net``."""
        if args.optim == 'adam':
            return torch.optim.Adam(params=net.parameters(), lr=args.lr)
        raise NotImplementedError

    optimizer = build_optimizer(model)

    if args.load_model:
        load_checkpoint(model, optimizer, filename=prefix)

    test_dict, val_dict, inf_set = gen_dataset(g_class, graphs, args)

    logger.info(
        f'Samples: valid {inf_set["val"]["num_pos"]} by {inf_set["val"]["num_neg"]} '
        f'test {inf_set["test"]["num_pos"]} by {inf_set["test"]["num_neg"]} metric: {args.metric}')

    G_pos, G_train = graphs['pos'], graphs['train']
    num_pos, num_seed, num_cand = len(set(G_pos.indices)), 100, 5

    # seed pool: the num_seed rows of G_pos with the most stored entries, largest first
    candidates = G_pos.getnnz(axis=1).argsort()[-num_seed:][::-1]

    rw_dict = {}  # node -> (walks, node ids, counts), filled lazily during round 1
    B_queues = []  # training batches sampled in round 1 and replayed in later rounds

    for r in range(1, args.repeat + 1):
        res_dict = {'test_AUC': [], 'val_AUC': [], f'test_{args.metric}': [], f'val_{args.metric}': []}
        if r > 1 or not args.load_model:
            # keep the weights restored by --load_model for the first round
            model.reset_parameters()
        if r > 1:
            # every repeat starts from a fresh optimizer state (e.g. Adam moments)
            optimizer = build_optimizer(model)
        logger.info(f'Running Round {r}')
        batchIdx, patience = 0, 0
        np.random.shuffle(B_queues)
        while True:
            if r <= 1:
                # round 1: sample a new batch and run walks for nodes not cached yet
                seeds = np.random.choice(candidates, num_cand, replace=False)
                B_queues.append(sorted(run_sample(G_pos.indptr, G_pos.indices, seeds, thld=args.B_size)))
                B_pos = B_queues[batchIdx]
                B_w = [b for b in B_pos if b not in rw_dict]
                if len(B_w) > 0:
                    walk_set, freqs = run_walk(G_train.indptr, G_train.indices, B_w, num_walks=args.num_walk,
                                               num_steps=args.num_step - 1, replacement=True)
                    node_id, node_freq = freqs[:, 0], freqs[:, 1]
                    rw_dict.update(dict(zip(B_w, zip(walk_set, node_id, node_freq))))
            else:
                # later rounds replay the (shuffled) round-1 batches once
                if batchIdx >= len(B_queues):
                    break
                B_pos = B_queues[batchIdx]
            batchIdx += 1

            # obtain set of walks, node id and DE (counts) from the dictionary
            S, K, F = zip(*itemgetter(*B_pos)(rw_dict))
            B_pos_edge, _ = subgraph(list(B_pos), T_edge_idx)
            B_full_edge, _ = subgraph(list(B_pos), F_edge_idx)
            data = gen_sample(np.asarray(S), B_pos, K, B_pos_edge, B_full_edge, inf_set['X'], args,
                              gtype=g_class.gtype)
            F = np.concatenate(F)
            # prepend an all-zero row before normalization
            mF = torch.from_numpy(np.concatenate([[[0] * F.shape[-1]], F])).to(device)
            gT = normalization(mF, args)
            loss, auc = train(model, optimizer, data, gT)
            logger.info(f'Batch {batchIdx}\tW{len(rw_dict)}/D{num_pos}\tLoss: {loss:.4f}, AUC: {auc:.4f}')
            tb.add_scalar("Loss/train", loss, batchIdx)
            tb.add_scalar("AUC/train", auc, batchIdx)

            if batchIdx > args.rtest and batchIdx % args.eval_steps == 0:
                bvtime = time.time()
                out = eval_model(model, val_dict, inf_set, args, evaluator, device, mode='val')
                if log_record(logger, tb, out, res_dict, bvtime, batchIdx):
                    # validation improved: evaluate on test and checkpoint if best
                    patience = 0
                    bttime = time.time()
                    out = eval_model(model, test_dict, inf_set, args, evaluator, device, mode='test')
                    if log_record(logger, tb, out, res_dict, bttime, batchIdx):
                        checkpoint = {'state_dict': model.state_dict(),
                                      'optimizer': optimizer.state_dict(),
                                      'epoch': batchIdx}
                        save_checkpoint(checkpoint, filename=prefix)
                else:
                    patience += 1

            if patience > args.patience:
                break
    tb.close()
210 |
--------------------------------------------------------------------------------
/main_hetro.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 |
6 | import torch
7 | from surel_gacc import run_sample
8 | from ogb.linkproppred import Evaluator
9 | from torch.utils.tensorboard import SummaryWriter
10 | from torch_geometric.utils import subgraph
11 |
12 | from dataloaders import *
13 | from log import *
14 | from models.model import Net
15 | from train import *
16 | from utils import *
17 | import sys
18 |
if __name__ == '__main__':
    # Entry point: train SUREL for relation prediction on the heterogeneous MAG graph,
    # evaluating on validation/test periodically with patience-based early stopping.
    parser = argparse.ArgumentParser('Interface for SUREL (Relation Prediction)')

    def str2bool(value):
        """Parse a boolean command-line value such as 'True'/'False', 'yes'/'no' or '1'/'0'."""
        # argparse's ``type=bool`` treats every non-empty string (including 'False') as True.
        return str(value).strip().lower() in ('true', 't', 'yes', 'y', '1')

    # general model and training setting
    parser.add_argument('--dataset', type=str, default='mag', help='dataset name',
                        choices=['mag'])
    parser.add_argument('--relation', type=str, default='cite', help='relation type',
                        choices=['write', 'cite'])
    parser.add_argument('--model', type=str, default='RNN', help='base model to use',
                        choices=['RNN', 'MLP', 'Transformer', 'GNN'])
    parser.add_argument('--layers', type=int, default=2, help='number of layers')
    parser.add_argument('--hidden_dim', type=int, default=64, help='hidden dimension')
    parser.add_argument('--x_dim', type=int, default=0, help='dim of raw node features')
    parser.add_argument('--data_usage', type=float, default=1.0, help='use partial dataset')
    parser.add_argument('--train_ratio', type=float, default=0.05, help='mask partial edges for training')
    parser.add_argument('--valid_ratio', type=float, default=0.1, help='use partial valid set')
    parser.add_argument('--test_ratio', type=float, default=1.0, help='use partial test set')
    parser.add_argument('--metric', type=str, default='mrr', help='metric for evaluating performance',
                        choices=['auc', 'mrr', 'hit'])
    parser.add_argument('--seed', type=int, default=0, help='seed to initialize all the random modules')
    parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
    parser.add_argument('--nthread', type=int, default=16, help='number of thread')

    # features and positional encoding
    parser.add_argument('--B_size', type=int, default=1500, help='set size of train sampling')
    parser.add_argument('--num_walk', type=int, default=100, help='total number of random walks')
    parser.add_argument('--num_step', type=int, default=4, help='total steps of random walk')
    parser.add_argument('--k', type=int, default=50, help='number of paired negative edges')
    parser.add_argument('--directed', type=str2bool, default=False, help='whether to treat the graph as directed')
    parser.add_argument('--use_feature', action='store_true', help='whether to use raw features as input')
    parser.add_argument('--use_weight', action='store_true', help='whether to use edge weight as input')
    parser.add_argument('--use_degree', action='store_true', help='whether to use node degree as input')
    parser.add_argument('--use_htype', action='store_true', help='whether to use node type as input')
    parser.add_argument('--use_val', action='store_true', help='whether to use val as input')
    parser.add_argument('--norm', type=str, default='all', help='method of normalization')

    # model training
    parser.add_argument('--optim', type=str, default='adam', help='optimizer to use')
    parser.add_argument('--rtest', type=int, default=499, help='step start to test')
    parser.add_argument('--eval_steps', type=int, default=200, help='number of steps to test')
    parser.add_argument('--batch_size', type=int, default=32, help='mini-batch size (train)')
    parser.add_argument('--batch_num', type=int, default=2000, help='mini-batch size (test)')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')
    parser.add_argument('--l2', type=float, default=0., help='l2 regularization (weight decay)')
    parser.add_argument('--patience', type=int, default=5, help='early stopping steps')
    parser.add_argument('--repeat', type=int, default=5, help='number of training instances to repeat')

    # logging & debug
    parser.add_argument('--log_dir', type=str, default='./log/', help='log directory')
    parser.add_argument('--res_dir', type=str, default='./dataset/save', help='resource directory')
    parser.add_argument('--stamp', type=str, default='', help='time stamp')
    parser.add_argument('--summary_file', type=str, default='result_summary.log',
                        help='brief summary of training results')
    parser.add_argument('--debug', default=False, action='store_true', help='whether to use debug mode')
    parser.add_argument('--save', default=False, action='store_true', help='whether to save RPE to files')
    parser.add_argument('--load_model', default=False, action='store_true',
                        help='whether to load saved model from files')
    parser.add_argument('--memo', type=str, help='notes')

    sys_argv = sys.argv
    try:
        args = parser.parse_args()
    except SystemExit as err:
        # argparse exits with status 0 after --help and 2 on invalid arguments: show the
        # full help for errors, but propagate the real exit status instead of masking it.
        if err.code:
            parser.print_help()
        raise

    # seed all random modules so that --seed makes runs reproducible
    set_random_seed(args)

    # customized for each dataset
    if 'mag' in args.dataset:
        args.metric = 'mrr'
    else:
        raise NotImplementedError

    # setup logger and tensorboard
    logger = set_up_log(args, sys_argv)
    if args.nthread > 0:
        torch.set_num_threads(args.nthread)
    logger.info(f"torch num_threads {torch.get_num_threads()}")
    tb = SummaryWriter()

    device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu')
    # checkpoint path prefix; args.stamp is filled in by set_up_log
    prefix = f'{args.res_dir}/model/{args.dataset}/{args.stamp}_{args.num_step}_{args.num_walk}'

    g_class = DE_Hetro_Dataset(args.dataset, args.relation)
    args.x_dim = len(g_class.node_type)
    graphs = g_class.process(logger)

    # edges for negative sampling
    T_edge_idx, F_edge_idx = g_class.pos_edge.t().contiguous(), g_class.train_edge.t().contiguous()

    # define model and optim
    model = Net(num_layers=args.layers, input_dim=args.num_step, hidden_dim=args.hidden_dim, out_dim=1,
                num_walk=args.num_walk, x_dim=args.x_dim, dropout=args.dropout, use_feature=args.use_feature,
                use_weight=args.use_weight, use_degree=args.use_degree, use_htype=args.use_htype)
    model.to(device)

    def build_optimizer(net):
        """Return a fresh optimizer (selected by --optim) over the parameters of ``net``."""
        if args.optim == 'adam':
            return torch.optim.Adam(params=net.parameters(), lr=args.lr)
        raise NotImplementedError

    optimizer = build_optimizer(model)

    evaluator = Evaluator(name=args.dataset) if 'mag' not in args.dataset else Evaluator(name='ogbl-citation2')

    if args.load_model:
        load_checkpoint(model, optimizer, filename=prefix)

    test_dict, val_dict, inf_set = gen_dataset(g_class, graphs, args)

    logger.info(
        f'Samples: valid {inf_set["val"]["num_pos"]} by {inf_set["val"]["num_neg"]} '
        f'test {inf_set["test"]["num_pos"]} by {inf_set["test"]["num_neg"]} metric: {args.metric}')

    G_pos, G_train = graphs['pos'], graphs['train']
    num_pos, num_seed, num_cand = len(set(G_pos.indices)), 100, 5

    deg = G_pos.getnnz(axis=1)
    # Node-index boundary between the two node groups; half of the seed pool comes from
    # each side (presumably the paper / non-paper split of MAG -- TODO confirm).
    type_split = 736389
    candidates = np.concatenate(
        [deg[:type_split].argsort()[-num_seed // 2:][::-1],
         deg[type_split:].argsort()[-num_seed // 2:][::-1] + type_split])

    rw_dict = {}  # node -> (walks, node ids, counts), filled lazily during round 1
    B_queues = []  # training batches sampled in round 1 and replayed in later rounds

    for r in range(1, args.repeat + 1):
        res_dict = {'test_AUC': [], 'val_AUC': [], f'test_{args.metric}': [], f'val_{args.metric}': []}
        if r > 1 or not args.load_model:
            # keep the weights restored by --load_model for the first round
            model.reset_parameters()
        if r > 1:
            # every repeat starts from a fresh optimizer state (e.g. Adam moments)
            optimizer = build_optimizer(model)
        logger.info(f'Running Round {r}')
        batchIdx, patience = 0, 0
        np.random.shuffle(B_queues)
        while True:
            if r <= 1:
                # round 1: sample a new batch and run walks for nodes not cached yet
                seeds = np.random.choice(candidates, num_cand, replace=False)
                B_queues.append(sorted(run_sample(G_pos.indptr, G_pos.indices, seeds, thld=args.B_size)))
                B_pos = B_queues[batchIdx]
                B_w = [b for b in B_pos if b not in rw_dict]
                if len(B_w) > 0:
                    walk_set, freqs = run_walk(G_train.indptr, G_train.indices, B_w, num_walks=args.num_walk,
                                               num_steps=args.num_step - 1, replacement=True)
                    node_id, node_freq = freqs[:, 0], freqs[:, 1]
                    rw_dict.update(dict(zip(B_w, zip(walk_set, node_id, node_freq))))
            else:
                # later rounds replay the (shuffled) round-1 batches once
                if batchIdx >= len(B_queues):
                    break
                B_pos = B_queues[batchIdx]
            batchIdx += 1

            # obtain set of walks, node id and RPE (counts) from the dictionary
            S, K, F = zip(*itemgetter(*B_pos)(rw_dict))
            B_pos_edge, _ = subgraph(list(B_pos), T_edge_idx)
            B_full_edge, _ = subgraph(list(B_pos), F_edge_idx)
            data = gen_sample(np.asarray(S), B_pos, K, B_pos_edge, B_full_edge, inf_set['X'], args,
                              gtype=g_class.gtype)
            F = np.concatenate(F)
            # prepend an all-zero row before normalization
            mF = torch.from_numpy(np.concatenate([[[0] * F.shape[-1]], F])).to(device)
            gT = normalization(mF, args)
            loss, auc = train(model, optimizer, data, gT)
            logger.info(f'Batch {batchIdx}\tW{len(rw_dict)}/D{num_pos}\tLoss: {loss:.4f}, AUC: {auc:.4f}')
            tb.add_scalar("Loss/train", loss, batchIdx)
            tb.add_scalar("AUC/train", auc, batchIdx)

            if batchIdx > args.rtest and batchIdx % args.eval_steps == 0:
                bvtime = time.time()
                out = eval_model(model, val_dict, inf_set, args, evaluator, device, mode='val')
                if log_record(logger, tb, out, res_dict, bvtime, batchIdx):
                    # validation improved: evaluate on test and checkpoint if best
                    patience = 0
                    bttime = time.time()
                    out = eval_model(model, test_dict, inf_set, args, evaluator, device, mode='test')
                    if log_record(logger, tb, out, res_dict, bttime, batchIdx):
                        checkpoint = {'state_dict': model.state_dict(),
                                      'optimizer': optimizer.state_dict(),
                                      'epoch': batchIdx}
                        save_checkpoint(checkpoint, filename=prefix)
                else:
                    patience += 1

            if patience > args.patience:
                break
    tb.close()
197 |
--------------------------------------------------------------------------------
/main_horder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 | import sys
6 |
7 | from ogb.linkproppred import Evaluator
8 | from torch.utils.data import DataLoader
9 |
10 | from dataloaders import *
11 | from log import *
12 | from models.model_horder import HONet
13 | from train import *
14 |
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a boolean command-line value.

        ``type=bool`` is an argparse trap: ``bool('False')`` is ``True``, so
        ``--directed False`` used to switch the flag ON. This accepts the usual
        spellings (true/false, yes/no, 1/0, t/f, y/n; case-insensitive).
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser('Interface for SUREL (Higher-Order Prediction)')

    # general model and training setting
    parser.add_argument('--dataset', type=str, default='DBLP-coauthor', help='dataset name',
                        choices=['DBLP-coauthor', 'tags-math'])
    parser.add_argument('--model', type=str, default='RNN', help='base model to use',
                        choices=['RNN', 'MLP', 'Transformer', 'GNN'])
    parser.add_argument('--layers', type=int, default=2, help='number of layers')
    parser.add_argument('--hidden_dim', type=int, default=64, help='hidden dimension')
    parser.add_argument('--x_dim', type=int, default=0, help='dim of raw node features')
    parser.add_argument('--data_usage', type=float, default=1.0, help='use partial dataset')
    parser.add_argument('--train_ratio', type=float, default=0.6, help='mask partial edges for training')
    parser.add_argument('--valid_ratio', type=float, default=0.1, help='use partial valid set')
    parser.add_argument('--test_ratio', type=float, default=1.0, help='use partial test set')
    parser.add_argument('--metric', type=str, default='mrr', help='metric for evaluating performance',
                        choices=['auc', 'mrr', 'hit'])
    parser.add_argument('--seed', type=int, default=0, help='seed to initialize all the random modules')
    parser.add_argument('--gpu_id', type=int, default=1, help='gpu id')
    parser.add_argument('--nthread', type=int, default=16, help='number of thread')

    # features and positional encoding
    parser.add_argument('--B_size', type=int, default=1500, help='set size of train sampling')
    parser.add_argument('--num_walk', type=int, default=100, help='total number of random walks')
    parser.add_argument('--num_step', type=int, default=3, help='total steps of random walk')
    parser.add_argument('--k', type=int, default=10, help='number of paired negative edges')
    # `--directed False` must really mean False (see _str2bool); `--directed True` still works.
    parser.add_argument('--directed', type=_str2bool, default=False, help='whether to treat the graph as directed')
    parser.add_argument('--use_feature', action='store_true', help='whether to use raw features as input')
    parser.add_argument('--use_weight', action='store_true', help='whether to use edge weight as input')
    parser.add_argument('--use_degree', action='store_true', help='whether to use node degree as input')
    parser.add_argument('--use_val', action='store_true', help='whether to use val as input')
    parser.add_argument('--norm', type=str, default='all', help='method of normalization')

    # model training
    parser.add_argument('--optim', type=str, default='adam', help='optimizer to use')
    parser.add_argument('--rtest', type=int, default=499, help='step start to test')
    parser.add_argument('--eval_steps', type=int, default=100, help='number of steps to test')
    parser.add_argument('--batch_size', type=int, default=32, help='mini-batch size (train)')
    parser.add_argument('--batch_num', type=int, default=2000, help='mini-batch size (test)')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')
    parser.add_argument('--l2', type=float, default=0., help='l2 regularization (weight decay)')
    parser.add_argument('--patience', type=int, default=3, help='early stopping steps')
    parser.add_argument('--repeat', type=int, default=5, help='number of training instances to repeat')

    # logging & debug
    parser.add_argument('--log_dir', type=str, default='./log/', help='log directory')
    parser.add_argument('--res_dir', type=str, default='./dataset/save', help='resource directory')
    parser.add_argument('--stamp', type=str, default='', help='time stamp')
    parser.add_argument('--summary_file', type=str, default='result_summary.log',
                        help='brief summary of training results')
    parser.add_argument('--debug', default=False, action='store_true', help='whether to use debug mode')
    parser.add_argument('--save', default=False, action='store_true', help='whether to save RPE to files')
    parser.add_argument('--load_model', default=False, action='store_true',
                        help='whether to load saved model from files')
    parser.add_argument('--memo', type=str, help='notes')

    sys_argv = sys.argv
    try:
        args = parser.parse_args()
    except SystemExit as exc:
        # argparse exits with status 0 after --help and 2 on invalid arguments. The former bare
        # `except:` printed the help a second time and always exited 0, hiding bad invocations
        # from shells/schedulers. Keep showing the full help on errors, but propagate the status.
        if exc.code not in (0, None):
            parser.print_help()
        raise

    # setup logger
    logger = set_up_log(args, sys_argv)
    if args.nthread > 0:
        torch.set_num_threads(args.nthread)
    logger.info(f"torch num_threads {torch.get_num_threads()}")

    device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu')
    # checkpoint path shared by save_checkpoint / load_checkpoint below
    prefix = f'{args.res_dir}/model/{args.dataset}/{args.stamp}_{args.num_step}_{args.num_walk}'
    g_class = DE_Hyper_Dataset(args.dataset)
    G_enc = g_class.process(logger)

    # define model and optim
    model = HONet(num_layers=args.layers, input_dim=args.num_step, hidden_dim=args.hidden_dim, out_dim=1,
                  num_walk=args.num_walk, x_dim=args.x_dim, dropout=args.dropout)
    model.to(device)

    if args.optim == 'adam':
        optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError

    evaluator = Evaluator(name='ogbl-citation2')

    if args.load_model:
        load_checkpoint(model, optimizer, filename=prefix)

    inf_dict, inf_set = gen_dataset_hyper(g_class, G_enc, args)

    logger.info(
        f'Samples: valid {inf_set["val"]["num_pos"]} by {inf_set["val"]["num_neg"]} '
        f'test {inf_set["test"]["num_pos"]} by {inf_set["test"]["num_neg"]} metric: {args.metric}')

    # Cache of sampled walks per node: filled lazily and reused across batches and rounds.
    rw_dict = {}
    triplets = g_class.split_edge['train']['hedge']
    num_pos = len(set(G_enc.indices))
    loader = DataLoader(range(len(triplets)), args.batch_size, shuffle=True)

    for r in range(1, args.repeat + 1):
        res_dict = {'test_AUC': [], 'val_AUC': [], f'test_{args.metric}': [], f'val_{args.metric}': []}
        # Re-initialise for every repeat, except that round 1 keeps the weights restored by
        # --load_model (resetting unconditionally used to discard the loaded checkpoint).
        if not (args.load_model and r == 1):
            model.reset_parameters()
        logger.info(f'Running Round {r}')
        batchIdx, patience = 0, 0
        for perm in loader:
            batchIdx += 1
            batch = triplets[perm]
            B_pos = np.unique(batch)
            # only nodes never seen before need fresh walks
            B_w = [b for b in B_pos if b not in rw_dict]
            if len(B_w) > 0:
                walk_set, freqs = run_walk(G_enc.indptr, G_enc.indices, B_w, num_walks=args.num_walk,
                                           num_steps=args.num_step - 1, replacement=True)
                node_id, node_freq = freqs[:, 0], freqs[:, 1]
                rw_dict.update(dict(zip(B_w, zip(walk_set, node_id, node_freq))))

            # obtain set of walks, node id and DE (counts) from the dictionary
            W, S, F = zip(*itemgetter(*B_pos)(rw_dict))
            data = gen_tuple(np.asarray(W), B_pos, S, batch, args)
            F = np.concatenate(F)
            # prepend an all-zero row to the encoding table
            mF = torch.from_numpy(np.concatenate([[[0] * F.shape[-1]], F])).to(device)
            gT = normalization(mF, args)
            loss, auc = train(model, optimizer, data, gT)
            logger.info(f'Batch {batchIdx}\tW{len(rw_dict)}/D{num_pos}\tLoss: {loss:.4f}, AUC: {auc:.4f}')

            if batchIdx > args.rtest and batchIdx % args.eval_steps == 0:
                bvtime = time.time()
                out = eval_model_horder(model, inf_dict, inf_set, args, evaluator, device, mode='val')
                if log_record(logger, None, out, res_dict, bvtime, batchIdx):
                    # validation improved: reset patience and evaluate on the test split
                    patience = 0
                    bttime = time.time()
                    out = eval_model_horder(model, inf_dict, inf_set, args, evaluator, device, mode='test')
                    if log_record(logger, None, out, res_dict, bttime, batchIdx):
                        checkpoint = {'state_dict': model.state_dict(),
                                      'optimizer': optimizer.state_dict(),
                                      'epoch': batchIdx}
                        save_checkpoint(checkpoint, filename=prefix)
                else:
                    patience += 1
            if patience > args.patience:
                break
158 |
--------------------------------------------------------------------------------
/models/layer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
class MLP(nn.Module):
    """Multi-layer perceptron with a linear output layer.

    Hidden layers are ``Linear -> BatchNorm1d -> ReLU``; the last layer is a plain
    ``Linear`` (no activation).

    Args:
        num_layers: number of layers EXCLUDING the input layer; must be >= 1.
            ``num_layers == 1`` reduces to a single linear model.
        input_dim: dimensionality of the input features.
        hidden_dim: width of every hidden layer (unused when ``num_layers == 1``).
        output_dim: dimensionality of the output (e.g. number of classes).

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        self.num_layers = num_layers
        self.linear_or_not = num_layers == 1  # True -> single linear map
        if self.linear_or_not:
            self.linear = nn.Linear(input_dim, output_dim)
            return

        # widths: input -> hidden x (num_layers - 1) -> output
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = torch.nn.ModuleList(
            nn.Linear(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:]))
        self.batch_norms = torch.nn.ModuleList(
            nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1))

    def forward(self, x):
        """Apply the network to ``x`` (last dimension must equal ``input_dim``)."""
        if self.linear_or_not:
            return self.linear(x)
        *hidden_layers, head = self.linears
        h = x
        for lin, norm in zip(hidden_layers, self.batch_norms):
            h = F.relu(norm(lin(h)))
        return head(h)

    def reset_parameters(self):
        """Re-initialise every linear and normalisation layer (used for retraining)."""
        if self.num_layers == 1:
            self.linear.reset_parameters()
            return
        for module in [*self.linears, *self.batch_norms]:
            module.reset_parameters()
67 |
68 |
class RNN(nn.Module):
    """Recurrent walk encoder: encode every walk, then average the encodings per group.

    Args:
        num_layers: number of stacked recurrent layers.
        input_dim: feature size of each step of a walk.
        hidden_dim: hidden size of the recurrent cell (also the output size).
        out_dim: accepted for interface compatibility; not used.
        dropout: dropout applied to the per-walk encodings before averaging.
        mtype: recurrent cell type, one of ``'lstm'`` (default), ``'gru'`` or ``'rnn'``.
            Previously this was stored but ignored (an LSTM was always built).

    Raises:
        ValueError: if ``mtype`` is not one of the supported cell types.
    """

    _CELLS = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, num_layers, input_dim, hidden_dim, out_dim, dropout=0.0, mtype='lstm'):
        super(RNN, self).__init__()
        if mtype not in self._CELLS:
            raise ValueError('Unknown rnn_type. Valid options: "gru", "lstm", or "rnn"')
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.rnn = self._CELLS[mtype](input_size=input_dim,
                                      hidden_size=hidden_dim,
                                      num_layers=num_layers,
                                      batch_first=True)
        self.rnn_type = mtype
        self.dropout = dropout

    def forward(self, x, walks):
        """Encode ``x`` of shape (num_seq, seq_len, input_dim) (batch_first).

        The last time step of each sequence is taken, sequences are grouped into
        consecutive chunks of ``walks`` (num_seq must be divisible by ``walks``), and
        each chunk is averaged. Returns a tensor of shape (num_seq // walks, hidden_dim).
        """
        out, _ = self.rnn(x)
        enc = out.select(dim=1, index=-1).view(-1, walks, self.hidden_dim)
        enc = F.dropout(enc, p=self.dropout, training=self.training)
        enc_agg = torch.mean(enc, dim=1)
        return enc_agg

    def init_hidden(self, batch_size):
        """Return a zero initial hidden state on the module's device/dtype.

        The LSTM state is a ``(h0, c0)`` tuple; GRU/RNN states are a single tensor.
        All tensors have shape (num_layers, batch_size, hidden_dim) (unidirectional).
        The previous version referenced the nonexistent ``self.device`` and
        ``self.directions_count`` and therefore always raised AttributeError.
        """
        param = next(self.rnn.parameters())
        shape = (self.num_layers, batch_size, self.hidden_dim)
        if self.rnn_type == 'lstm':
            return (torch.zeros(shape, device=param.device, dtype=param.dtype),
                    torch.zeros(shape, device=param.device, dtype=param.dtype))
        if self.rnn_type in ('gru', 'rnn'):
            return torch.zeros(shape, device=param.device, dtype=param.dtype)
        raise Exception('Unknown rnn_type. Valid options: "gru", "lstm", or "rnn"')

    def reset_parameters(self):
        """Re-initialise the recurrent cell (used for retraining)."""
        self.rnn.reset_parameters()
103 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 | from models.layer import MLP, RNN
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 |
9 |
class MergeLayer(torch.nn.Module):
    """Score a pair of embeddings.

    ``non_linear=True``: ``concat(x1, x2) -> Linear(dim1 + dim2, dim3) -> ReLU -> dropout
    -> Linear(dim3, dim4)``; returns ``(score, None)``.

    ``non_linear=False`` (requires ``dim1 == dim2``): ``x1``/``x2`` are concatenated along
    dim -2, each row is scored by ``Linear(dim1, 1)``, and the row scores are summed;
    returns ``(score, per_row_scores)``.
    """

    def __init__(self, dim1, dim2, dim3, dim4, non_linear=True, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(dim1 + dim2, dim3)
        self.fc2 = nn.Linear(dim3, dim4)
        self.act = nn.ReLU()
        self.dropout = dropout

        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)

        self.non_linear = non_linear
        if not non_linear:
            assert (dim1 == dim2)
            self.fc = nn.Linear(dim1, 1)
            # Initialise the new head; this line previously re-initialised fc1 by mistake.
            nn.init.xavier_normal_(self.fc.weight)

    def forward(self, x1, x2):
        z_walk = None
        if self.non_linear:
            x = torch.cat([x1, x2], dim=-1)
            h = self.act(self.fc1(x))
            h = F.dropout(h, p=self.dropout, training=self.training)
            z = self.fc2(h)
        else:
            # x1, x2 shape: [B, M, F]
            x = torch.cat([x1, x2], dim=-2)  # x shape: [B, 2M, F]
            z_walk = self.fc(x).squeeze(-1)  # z_walk shape: [B, 2M]
            z = z_walk.sum(dim=-1, keepdim=True)  # z shape [B, 1]
        return z, z_walk

    def reset_parameter(self):
        """Re-initialise all layers the same way ``__init__`` does.

        The fc1/fc2 call order is kept so seeded runs reproduce; the linear-mode head
        ``fc`` was previously never reset.
        """
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)
        if not self.non_linear:
            self.fc.reset_parameters()
            nn.init.xavier_normal_(self.fc.weight)
46 |
47 |
class Net(torch.nn.Module):
    """Pairwise scorer over walk-based relative positional encodings (RPE).

    The encodings are embedded by a small MLP (summed over dim -2), optionally combined
    with degree / raw-feature / one-hot node-type information, split into two halves
    (one per query node), encoded by a shared ``RNN`` and scored by ``MergeLayer``.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, out_dim, num_walk, x_dim=0, dropout=0.5,
                 use_feature=False, use_weight=False, use_degree=False, use_htype=False):
        super(Net, self).__init__()
        self.use_feature = use_feature
        self.use_degree = use_degree
        self.use_htype = use_htype
        self.dropout = dropout
        self.x_dim = x_dim
        self.enc = 'LP'  # landing prob at [0, 1, ... num_layers]

        # edge weight, when used, adds one extra input column
        embed_in = input_dim + (1 if use_weight else 0)
        self.trainable_embedding = nn.Sequential(
            nn.Linear(in_features=embed_in, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        )
        print("Relative Positional Encoding: {}".format(self.enc))

        # input width of the walk encoder depends on what gets concatenated in forward()
        if use_feature:
            rnn_in = hidden_dim * 2
        elif use_htype:
            rnn_in = hidden_dim + x_dim
        else:
            rnn_in = hidden_dim
        self.rnn = RNN(num_layers, rnn_in, hidden_dim, out_dim)

        if use_htype:
            # NOTE(review): registered (so it appears in state_dict) but not used by forward()
            self.ntype_embedding = nn.Sequential(
                nn.Linear(in_features=x_dim, out_features=hidden_dim),
                nn.ReLU(),
                nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            )
        self.affinity_score = MergeLayer(hidden_dim, hidden_dim, hidden_dim, 1, non_linear=True, dropout=dropout)
        self.concat_norm = nn.LayerNorm(hidden_dim * 2)  # not used by forward()
        self.len_step = input_dim
        self.walks = num_walk

    def forward(self, x, feature=None, debugs=None):
        """Return one score per query pair (``debugs`` is accepted but ignored)."""
        h = self.trainable_embedding(x).sum(dim=-2)

        if self.use_degree:
            h = h / torch.cat(feature[-1]).to(h.device)
        elif self.use_feature:
            h = torch.cat([h, feature[0].to(h.device)], dim=-1)
        elif self.use_htype:
            node_type = F.one_hot(feature[0], self.x_dim).to(h.device).float()
            h = torch.cat([node_type, h], dim=-1)

        # split into the two query nodes: [2, num_seq, len_step, width]
        halves = h.view(2, -1, self.len_step, h.shape[-1])
        enc_u = self.rnn(halves[0], self.walks)
        enc_v = self.rnn(halves[1], self.walks)
        score, _ = self.affinity_score(enc_u, enc_v)
        return score.squeeze(1)

    def reset_parameters(self):
        """Re-initialise the embedding MLP (xavier), the walk encoder and the scorer."""
        for module in self.trainable_embedding:
            if not hasattr(module, 'reset_parameters'):
                continue  # activation layers carry no parameters
            module.reset_parameters()
            nn.init.xavier_normal_(module.weight)
        self.rnn.reset_parameters()
        self.affinity_score.reset_parameter()
101 |
102 |
class LinkPredictor(torch.nn.Module):
    """MLP link scorer applied to the element-wise product of two node embeddings.

    Builds ``max(num_layers, 2)`` linear layers: ``in -> hidden -> ... -> hidden -> out``,
    with ReLU + dropout between them and no activation on the output.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(LinkPredictor, self).__init__()
        widths = [in_channels] + [hidden_channels] * max(num_layers - 1, 1) + [out_channels]
        self.lins = torch.nn.ModuleList(
            torch.nn.Linear(w_in, w_out) for w_in, w_out in zip(widths, widths[1:]))
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every linear layer."""
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, x_i, x_j):
        """Score the pair (``x_i``, ``x_j``); both must have ``in_channels`` features."""
        h = x_i * x_j
        *hidden_layers, output_layer = self.lins
        for layer in hidden_layers:
            h = F.dropout(F.relu(layer(h)), p=self.dropout, training=self.training)
        return output_layer(h)
127 |
--------------------------------------------------------------------------------
/models/model_horder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 | from models.layer import MLP, RNN
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 |
9 |
class MergeLayer(torch.nn.Module):
    """Score four embeddings jointly (higher-order variant).

    ``non_linear=True``: ``concat(x1..x4) -> Linear(4 * dim1, dim2) -> ReLU -> dropout
    -> Linear(dim2, dim3)``; returns ``(score, None)``.

    ``non_linear=False`` (requires ``dim1 == dim2 == dim3``): inputs are concatenated along
    dim -2, each row is scored by ``Linear(dim1, 1)`` and the row scores are summed;
    returns ``(score, per_row_scores)``.
    """

    def __init__(self, dim1, dim2, dim3, non_linear=True, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(dim1 * 4, dim2)
        self.fc2 = nn.Linear(dim2, dim3)
        self.act = nn.ReLU()
        self.dropout = dropout

        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)

        self.non_linear = non_linear
        if not non_linear:
            assert (dim1 == dim2 == dim3)
            self.fc = nn.Linear(dim1, 1)
            # Initialise the new head; this line previously re-initialised fc1 by mistake.
            nn.init.xavier_normal_(self.fc.weight)

    def forward(self, x1, x2, x3, x4):
        z_walk = None
        if self.non_linear:
            x = torch.cat([x1, x2, x3, x4], dim=-1)
            h = self.act(self.fc1(x))
            h = F.dropout(h, p=self.dropout, training=self.training)
            z = self.fc2(h)
        else:
            # x1..x4 shape: [B, M, F]
            x = torch.cat([x1, x2, x3, x4], dim=-2)  # x shape: [B, 4M, F]
            z_walk = self.fc(x).squeeze(-1)  # z_walk shape: [B, 4M]
            z = z_walk.sum(dim=-1, keepdim=True)  # z shape [B, 1]
        return z, z_walk

    def reset_parameter(self):
        """Re-initialise all layers exactly as ``__init__`` does.

        Previously the xavier re-init (used in ``__init__`` and by the pairwise
        MergeLayer) was skipped, and the linear-mode head ``fc`` was never reset.
        """
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)
        if not self.non_linear:
            self.fc.reset_parameters()
            nn.init.xavier_normal_(self.fc.weight)
44 |
class HONet(torch.nn.Module):
    """Higher-order scorer over walk-based relative positional encodings.

    The encodings are embedded by a small MLP (summed over dim -2), split into four
    groups, each group is encoded by one shared ``RNN``, and the four group encodings
    are scored jointly by ``MergeLayer``. (The original variable names ``wu, wv, uw, vw``
    suggest the groups are directional node pairs — NOTE(review): confirm with the
    data pipeline.)
    """

    def __init__(self, num_layers, input_dim, hidden_dim, out_dim, num_walk, x_dim=0, dropout=0.5):
        super(HONet, self).__init__()
        self.dropout = dropout
        self.x_dim = x_dim
        self.enc = 'LP'  # landing prob at [0, 1, ... num_layers]

        self.trainable_embedding = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        )

        print("Relative Positional Encoding: {}".format(self.enc))
        self.rnn = RNN(num_layers, hidden_dim, hidden_dim, out_dim)
        self.affinity_score = MergeLayer(hidden_dim, hidden_dim, 1, non_linear=True, dropout=dropout)
        self.concat_norm = nn.LayerNorm(hidden_dim * 2)  # not used by forward()
        self.len_step = input_dim
        self.walks = num_walk

    def forward(self, x):
        """Return one score per query tuple."""
        emb = self.trainable_embedding(x).sum(dim=-2)
        # [4, num_seq, len_step, hidden]: four groups sharing the same encoder
        groups = emb.view(4, -1, self.len_step, emb.shape[-1])
        encodings = [self.rnn(groups[g], self.walks) for g in range(4)]
        score, _ = self.affinity_score(*encodings)
        return score.squeeze(1)

    def reset_parameters(self):
        """Re-initialise the embedding MLP (xavier), the walk encoder and the scorer."""
        for module in self.trainable_embedding:
            if not hasattr(module, 'reset_parameters'):
                continue  # activation layers carry no parameters
            module.reset_parameters()
            nn.init.xavier_normal_(module.weight)
        self.rnn.reset_parameters()
        self.affinity_score.reset_parameter()
77 |
--------------------------------------------------------------------------------
/subg_acc/.gitignore:
--------------------------------------------------------------------------------
1 | # Source: https://github.com/github/gitignore/blob/main/C.gitignore
2 | # Prerequisites
3 | *.d
4 |
5 | # Object files
6 | *.o
7 | *.ko
8 | *.obj
9 | *.elf
10 |
11 | # Linker output
12 | *.ilk
13 | *.map
14 | *.exp
15 |
16 | # Precompiled Headers
17 | *.gch
18 | *.pch
19 |
20 | # Libraries
21 | *.lib
22 | *.a
23 | *.la
24 | *.lo
25 |
26 | # Shared objects (inc. Windows DLLs)
27 | *.dll
28 | *.so
29 | *.so.*
30 | *.dylib
31 |
32 | # Executables
33 | *.exe
34 | *.out
35 | *.app
36 | *.i*86
37 | *.x86_64
38 | *.hex
39 |
40 | # Debug files
41 | *.dSYM/
42 | *.su
43 | *.idb
44 | *.pdb
45 |
46 | # Kernel Module Compile Results
47 | *.mod*
48 | *.cmd
49 | .tmp_versions/
50 | modules.order
51 | Module.symvers
52 | Mkfile.old
53 | dkms.conf
54 |
--------------------------------------------------------------------------------
/subg_acc/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2021-Present, Haoteng Yin
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/subg_acc/README.md:
--------------------------------------------------------------------------------
1 | # Subgraph Operation Accelerator
2 |
3 |
4 |
5 |
6 |
7 |
8 | The `subg_acc` package is a C/OpenMP extension library that accelerates subgraph operations in subgraph-based graph representation learning (SGRL) through multithreading. Following the principle of algorithm-system co-design in [SUREL](https://arxiv.org/abs/2202.13538)/[SUREL+](https://github.com/VeritasYin/SUREL_Plus/blob/main/manuscript/SUREL_Plus_Full.pdf), query-level subgraphs (of links/motifs, e.g., the ego-networks used in canonical SGRLs) are decomposed into reusable node-level ones. Currently, `subg_acc` provides the following methods for building scalable SGRLs:
9 |
10 | - `run_walk` walk-based subgraph sampling
11 | - `run_sample` walk-based sampling of training batches
12 | - `rpe_encoder` relative positional encoding (localized structural feature construction)
13 | - `sjoin` online subgraph joining that rebuilds the query-level subgraph from node-level ones to serve queries (a set of nodes)
14 |
15 | ## Requirements
16 | (Other versions may work, but are untested)
17 |
18 | - python >= 3.8
19 | - gcc >= 8.4
20 | - cmake >= 3.16
21 | - make >= 4.2
22 |
23 | ## Installation
24 | ```
25 | python setup.py install
26 | ```
27 |
28 |
--------------------------------------------------------------------------------
/subg_acc/graph_acc.c:
--------------------------------------------------------------------------------
1 | #define PY_SSIZE_T_CLEAN
2 |
3 | #include
4 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
5 | #include
6 | #include
7 | #include "uthash.h"
8 |
9 | #define DEBUG 0
10 |
/* uthash entry mapping an int key to an int value. Used as a set of node ids
 * (np_sample reads only `key`) and as a node id -> local index map (rpe_encoder). */
typedef struct item_int {
    int key;            /* hash key (a node id in this file's callers) */
    int val;            /* payload, e.g. the local index assigned by rpe_encoder */
    UT_hash_handle hh;  /* makes this structure hashable by uthash */
} dict_int;
16 |
17 | static int find_key_int(dict_int *maps, int key) {
18 | dict_int *s;
19 | HASH_FIND_INT(maps, &key, s); /* s: output pointer */
20 | return s ? s->val : -1;
21 | }
22 |
23 | void add_item(dict_int **maps, int key) {
24 | dict_int *s;
25 | HASH_FIND_INT(*maps, &key, s); /* s: output pointer */
26 | if (s == NULL) {
27 | dict_int *k = malloc(sizeof(*k));
28 | // walk starts from each node (main key)
29 | k->key = key;
30 | HASH_ADD_INT(*maps, key, k);
31 | }
32 | }
33 |
/* hash of hashes: an outer uthash table keyed by `key` whose entries each own an
 * inner table `sub` of the same type (looked up by find_idx, freed by delete_all). */
typedef struct item {
    int key;            /* outer or inner hash key */
    int val;            /* payload returned by find_key_item / find_idx */
    struct item *sub;   /* inner uthash table; uthash represents an empty table as NULL */
    UT_hash_handle hh;  /* makes this structure hashable by uthash */
} dict_item;
41 |
42 | static int find_key_item(dict_item *items, int key) {
43 | dict_item *s;
44 | HASH_FIND_INT(items, &key, s); /* s: output pointer */
45 | return s ? s->val : -1;
46 | }
47 |
48 | static int find_idx(dict_item *items, int id1, int id2) {
49 | dict_item *s, *p;
50 | HASH_FIND_INT(items, &id1, s); /* s: output pointer */
51 | if (s != NULL) {
52 | HASH_FIND_INT(s->sub, &id2, p);
53 | return p ? p->val : 0;
54 | } else {
55 | return -1;
56 | }
57 | }
58 |
59 | void delete_all(dict_item *maps) {
60 | dict_item *item1, *item2, *tmp1, *tmp2;
61 |
62 | /* clean up both hash tables */
63 | HASH_ITER(hh, maps, item1, tmp1) {
64 | HASH_ITER(hh, item1->sub, item2, tmp2) {
65 | HASH_DEL(item1->sub, item2);
66 | free(item2);
67 | }
68 | HASH_DEL(maps, item1);
69 | free(item1);
70 | }
71 | }
72 |
73 | // test func
74 | static PyObject *adds(PyObject *self, PyObject *args) {
75 | int arg1, arg2;
76 | if (!(PyArg_ParseTuple(args, "ii", &arg1, &arg2))) {
77 | return NULL;
78 | }
79 | return Py_BuildValue("i", arg1 * 2 + arg2 * 7);
80 | }
81 |
82 | static PyObject *exe(PyObject *self, PyObject *args) {
83 | const char *command;
84 | int sts;
85 |
86 | if (!PyArg_ParseTuple(args, "s", &command))
87 | return NULL;
88 | sts = system(command);
89 | return PyLong_FromLong(sts);
90 | }
91 |
92 | static void f_format(const npy_intp *dims, int *CArrays) {
93 | for (int x = 0; x < dims[0]; x++) {
94 | printf("idx %d: \n", x);
95 | for (int y = 0; y < dims[1]; y++) {
96 | printf("%d ", CArrays[x * dims[1] + y]);
97 | }
98 | printf("\n");
99 | }
100 | }
101 |
102 | static void
103 | random_walk(int const *ptr, int const *neighs, int const *seq, int n, int num_walks, int num_steps, int seed,
104 | int nthread, int *walks) {
105 | /* https://github.com/lkskstlr/rwalk */
106 | if (DEBUG) {
107 | printf("get in with n: %d, num_walks: %d, num_steps: %d, seed: %d, nthread: %d\n", n, num_walks, num_steps,
108 | seed, nthread);
109 | }
110 | if (nthread > 0) {
111 | omp_set_num_threads(nthread);
112 | }
113 | #pragma omp parallel
114 | {
115 | int thread_num = omp_get_thread_num();
116 | unsigned int private_seed = (unsigned int) (seed + thread_num);
117 | #pragma omp for
118 | for (int i = 0; i < n; i++) {
119 | int offset, num_neighs;
120 | for (int walk = 0; walk < num_walks; walk++) {
121 | int curr = seq[i];
122 | offset = i * num_walks * (num_steps + 1) + walk * (num_steps + 1);
123 | walks[offset] = curr;
124 | for (int step = 0; step < num_steps; step++) {
125 | num_neighs = ptr[curr + 1] - ptr[curr];
126 | if (num_neighs > 0) {
127 | curr = neighs[ptr[curr] + (rand_r(&private_seed) % num_neighs)];
128 | }
129 | walks[offset + step + 1] = curr;
130 | }
131 | }
132 | }
133 | }
134 | }
135 |
/*
 * Random walks whose FIRST hop is drawn without replacement from the start node's
 * neighbours; later hops are uniform with replacement. The output layout matches
 * random_walk: `walks` holds n * num_walks * (num_steps + 1) ints, and slot 0 of each
 * walk is the start node.
 *
 * First-hop choice for start node i with degree d:
 *   d == 0          -> the walk stays on the start node
 *   d <= num_walks  -> neighbours are cycled in CSR order (walk % d)
 *   d >  num_walks  -> a partial Fisher-Yates shuffle picks num_walks distinct neighbours
 *
 * Robustness fixes versus the previous version:
 *   - The shuffle buffer is heap-allocated. It used to be a stack VLA, which can overflow
 *     an OpenMP thread stack on high-degree nodes and is zero-length (UB) when d == 0.
 *     If the allocation fails, the first hop falls back to a uniform draw.
 *   - With num_steps == 0 the walk holds only its start node; the old code wrote one
 *     slot past it.
 *   - Offsets use size_t to avoid int overflow on large batches.
 */
static void
random_walk_wo(int const *ptr, int const *neighs, int const *seq, int n, int num_walks, int num_steps, int seed,
               int nthread, int *walks) {
    if (nthread > 0) {
        omp_set_num_threads(nthread);
    }
    const size_t walk_len = (size_t) num_steps + 1;
#pragma omp parallel
    {
        int thread_num = omp_get_thread_num();
        unsigned int private_seed = (unsigned int) (seed + thread_num);

#pragma omp for
        for (int i = 0; i < n; i++) {
            const int start = seq[i];
            const int num_hop1 = ptr[start + 1] - ptr[start];
            int *rseq = NULL;

            if (num_hop1 > num_walks) {
                /* partial Fisher-Yates: rseq[0..num_walks) become distinct neighbour slots */
                rseq = malloc((size_t) num_hop1 * sizeof(*rseq));
                if (rseq != NULL) {
                    for (int j = 0; j < num_hop1; j++)
                        rseq[j] = j;
                    for (int k = 0; k < num_walks; k++) {
                        int s = rand_r(&private_seed) % (num_hop1 - k) + k;
                        int t = rseq[k];
                        rseq[k] = rseq[s];
                        rseq[s] = t;
                    }
                }
            }

            for (int walk = 0; walk < num_walks; walk++) {
                size_t offset = ((size_t) i * (size_t) num_walks + (size_t) walk) * walk_len;
                int curr = start;
                walks[offset] = curr;
                if (num_steps < 1)
                    continue;  /* the walk consists of the start node only */

                if (num_hop1 >= 1) {
                    int pick;
                    if (num_hop1 <= num_walks)
                        pick = walk % num_hop1;
                    else if (rseq != NULL)
                        pick = rseq[walk];
                    else
                        pick = rand_r(&private_seed) % num_hop1;  /* allocation failed */
                    curr = neighs[ptr[curr] + pick];
                }
                walks[offset + 1] = curr;

                for (int step = 1; step < num_steps; step++) {
                    int num_neighs = ptr[curr + 1] - ptr[curr];
                    if (num_neighs > 0) {
                        curr = neighs[ptr[curr] + (rand_r(&private_seed) % num_neighs)];
                    }
                    walks[offset + step + 1] = curr;
                }
            }
            free(rseq);
        }
    }
}
192 |
193 |
194 | void rpe_encoder(int const *arr, int idx, int num_walks, int num_steps, PyArrayObject **out) {
195 | PyArrayObject *oarr1 = NULL, *oarr2 = NULL;
196 | dict_int *mapping = NULL;
197 | int offset = idx * num_walks * (num_steps + 1);
198 |
199 | // setup root node
200 | dict_int *root = malloc(sizeof(*root));
201 | root->key = arr[offset];
202 | root->val = 0;
203 | HASH_ADD_INT(mapping, key, root);
204 |
205 | // setup the rest unique node
206 | int count = 1;
207 | for (int i = 1; i < num_steps + 1; i++) {
208 | for (int j = 0; j < num_walks; j++) {
209 | int token = arr[offset + j * (num_steps + 1) + i];
210 | if (find_key_int(mapping, token) < 0) {
211 | dict_int *k = malloc(sizeof(*k));
212 | // walk starts from each node (main key)
213 | k->key = token;
214 | k->val = count;
215 | HASH_ADD_INT(mapping, key, k);
216 | count++;
217 | }
218 | }
219 | }
220 | int num_keys = HASH_COUNT(mapping);
221 |
222 | // create a new array
223 | npy_intp odims1[2] = {num_keys, num_steps + 1};
224 | oarr1 = (PyArrayObject *) PyArray_ZEROS(2, odims1, NPY_INT, 0);
225 | if (!oarr1) PyErr_SetString(PyExc_TypeError, "output error.");
226 | int *Coarr1 = (int *) PyArray_DATA(oarr1);
227 |
228 | npy_intp odims2[1] = {num_keys};
229 | oarr2 = (PyArrayObject *) PyArray_SimpleNew(1, odims2, NPY_INT);
230 | if (!oarr2) PyErr_SetString(PyExc_TypeError, "output error.");
231 | int *Coarr2 = (int *) PyArray_DATA(oarr2);
232 |
233 | Coarr1[0] = num_walks;
234 |
235 | for (int i = 1; i < num_steps + 1; i++) {
236 | for (int j = 0; j < num_walks; j++) {
237 | int anchor = find_key_int(mapping, arr[offset + j * (num_steps + 1) + i]);
238 | Coarr1[anchor * (num_steps + 1) + i]++;
239 | }
240 | }
241 |
242 | // free mem
243 | dict_int *cur_item, *tmp;
244 | HASH_ITER(hh, mapping, cur_item, tmp) {
245 | Coarr2[cur_item->val] = cur_item->key;
246 | HASH_DEL(mapping, cur_item); /* delete it (users advances to next) */
247 | free(cur_item); /* free it */
248 | }
249 | out[2 * idx] = oarr2, out[2 * idx + 1] = oarr1;
250 | }
251 |
252 | static PyObject *np_sample(PyObject *self, PyObject *args, PyObject *kws) {
253 | PyObject *arg1 = NULL, *arg2 = NULL, *query = NULL;
254 | PyArrayObject *ptr = NULL, *neighs = NULL, *seq = NULL, *oarr = NULL;
255 | int num_walks = 200, num_steps = 8, seed = 111413, nthread = -1, thld = 1000;
256 |
257 | static char *kwlist[] = {"ptr", "neighs", "query", "num_walks", "num_steps", "thld", "nthread", "seed", NULL};
258 | if (!(PyArg_ParseTupleAndKeywords(args, kws, "OOO|iiiiip", kwlist, &arg1, &arg2, &query, &num_walks, &num_steps,
259 | &thld, &nthread, &seed))) {
260 | PyErr_SetString(PyExc_TypeError, "input parsing error.");
261 | return NULL;
262 | }
263 |
264 | /* handle walks (numpy array) */
265 | ptr = (PyArrayObject *) PyArray_FROM_OTF(arg1, NPY_INT, NPY_ARRAY_IN_ARRAY);
266 | if (!ptr) return NULL;
267 | int *Cptr = PyArray_DATA(ptr);
268 |
269 | neighs = (PyArrayObject *) PyArray_FROM_OTF(arg2, NPY_INT, NPY_ARRAY_IN_ARRAY);
270 | if (!neighs) return NULL;
271 | int *Cneighs = PyArray_DATA(neighs);
272 |
273 | seq = (PyArrayObject *) PyArray_FROM_OTF(query, NPY_INT, NPY_ARRAY_IN_ARRAY | NPY_ARRAY_FORCECAST);
274 | if (!seq) return NULL;
275 | int *Cseq = PyArray_DATA(seq);
276 |
277 | int n = (int) PyArray_SIZE(seq);
278 |
279 | unsigned int private_seed = (unsigned int) (seed + getpid());
280 | /* initialize the hashtable */
281 | dict_int *sets = NULL;
282 | for (int i = 0; i < n; i++) {
283 | int num_hop1 = Cptr[Cseq[i] + 1] - Cptr[Cseq[i]];
284 | int num_neighs, rseq[num_hop1];
285 | if (num_hop1 > num_walks) {
286 | int s, t;
287 | for (int j = 0; j < num_hop1; j++)
288 | rseq[j] = j;
289 | for (int k = 0; k < num_walks; k++) {
290 | s = rand_r(&private_seed) % (num_hop1 - k) + k;
291 | t = rseq[k];
292 | rseq[k] = rseq[s];
293 | rseq[s] = t;
294 | }
295 | }
296 | add_item(&sets, Cseq[i]);
297 |
298 | for (int walk = 0; walk < num_walks; walk++) {
299 | int curr = Cseq[i];
300 | if (num_hop1 < 1) {
301 | break;
302 | } else if (num_hop1 <= num_walks) {
303 | curr = Cneighs[Cptr[curr] + walk % num_hop1];
304 | } else {
305 | curr = Cneighs[Cptr[curr] + rseq[walk]];
306 | }
307 | for (int step = 1; step < num_steps; step++) {
308 | add_item(&sets, curr);
309 | num_neighs = Cptr[curr + 1] - Cptr[curr];
310 | if (num_neighs > 0) {
311 | curr = Cneighs[Cptr[curr] + (rand_r(&private_seed) % num_neighs)];
312 | add_item(&sets, curr);
313 | }
314 | }
315 | if ((int) HASH_COUNT(sets) >= ((i + 1) * thld / n))
316 | break;
317 | }
318 | }
319 |
320 | npy_intp odims[1] = {HASH_COUNT(sets)};
321 | oarr = (PyArrayObject *) PyArray_SimpleNew(1, odims, NPY_INT);
322 | if (oarr == NULL) goto fail;
323 | int *Coarr = (int *) PyArray_DATA(oarr);
324 |
325 | // free memory
326 | dict_int *cur_item, *tmp;
327 | int idx = 0;
328 | HASH_ITER(hh, sets, cur_item, tmp) {
329 | Coarr[idx] = cur_item->key;
330 | HASH_DEL(sets, cur_item); /* delete it (users advances to next) */
331 | free(cur_item); /* free it */
332 | idx++;
333 | }
334 |
335 | Py_DECREF(ptr);
336 | Py_DECREF(neighs);
337 | Py_DECREF(seq);
338 | return PyArray_Return(oarr);
339 |
340 | fail:
341 | Py_XDECREF(ptr);
342 | Py_XDECREF(neighs);
343 | Py_XDECREF(seq);
344 | PyArray_DiscardWritebackIfCopy(oarr);
345 | PyArray_XDECREF(oarr);
346 | return NULL;
347 | }
348 |
349 | static PyObject *np_walk(PyObject *self, PyObject *args, PyObject *kws) {
350 | PyObject *arg1 = NULL, *arg2 = NULL, *query = NULL;
351 | PyArrayObject *ptr = NULL, *neighs = NULL, *seq = NULL, *oarr = NULL, *obj_arr = NULL;
352 | int num_walks = 100, num_steps = 3, seed = 111413, nthread = -1, re = -1;
353 | int n;
354 |
355 | static char *kwlist[] = {"ptr", "neighs", "query", "num_walks", "num_steps", "nthread", "seed", "replacement", NULL};
356 | if (!(PyArg_ParseTupleAndKeywords(args, kws, "OOO|iiiip", kwlist, &arg1, &arg2, &query, &num_walks, &num_steps,
357 | &nthread, &seed, &re))) {
358 | PyErr_SetString(PyExc_TypeError, "input parsing error.");
359 | return NULL;
360 | }
361 |
362 | /* handle walks (numpy array) */
363 | ptr = (PyArrayObject *) PyArray_FROM_OTF(arg1, NPY_INT, NPY_ARRAY_IN_ARRAY);
364 | if (!ptr) return NULL;
365 | int *Cptr = PyArray_DATA(ptr);
366 |
367 | neighs = (PyArrayObject *) PyArray_FROM_OTF(arg2, NPY_INT, NPY_ARRAY_IN_ARRAY);
368 | if (!neighs) return NULL;
369 | int *Cneighs = PyArray_DATA(neighs);
370 |
371 | seq = (PyArrayObject *) PyArray_FROM_OTF(query, NPY_INT, NPY_ARRAY_IN_ARRAY | NPY_ARRAY_FORCECAST);
372 | if (!seq) return NULL;
373 | int *Cseq = PyArray_DATA(seq);
374 |
375 | n = (int) PyArray_SIZE(seq);
376 |
377 | npy_intp odims[2] = {n, num_walks * (num_steps + 1)};
378 | oarr = (PyArrayObject *) PyArray_SimpleNew(2, odims, NPY_INT);
379 | if (oarr == NULL) goto fail;
380 | int *Coarr = (int *) PyArray_DATA(oarr);
381 |
382 | npy_intp obj_dims[2] = {n, 2};
383 | obj_arr = (PyArrayObject *) PyArray_SimpleNew(2, obj_dims, NPY_OBJECT);
384 | if (obj_arr == NULL) goto fail;
385 | PyArrayObject **Cobj_arr = PyArray_DATA(obj_arr);
386 |
387 | if (re > 0) {
388 | // printf("Using no replacement sampling for the 1-hop.\n");
389 | random_walk_wo(Cptr, Cneighs, Cseq, n, num_walks, num_steps, seed, nthread, Coarr);
390 | } else {
391 | random_walk(Cptr, Cneighs, Cseq, n, num_walks, num_steps, seed, nthread, Coarr);
392 | }
393 |
394 | #pragma omp for
395 | for (int k = 0; k < n; k++) {
396 | rpe_encoder(Coarr, k, num_walks, num_steps, Cobj_arr);
397 | }
398 |
399 | Py_DECREF(ptr);
400 | Py_DECREF(neighs);
401 | Py_DECREF(seq);
402 | return Py_BuildValue("[N,N]", PyArray_Return(oarr), PyArray_Return(obj_arr));
403 |
404 | fail:
405 | Py_XDECREF(ptr);
406 | Py_XDECREF(neighs);
407 | Py_XDECREF(seq);
408 | PyArray_DiscardWritebackIfCopy(oarr);
409 | PyArray_XDECREF(oarr);
410 | PyArray_DiscardWritebackIfCopy(obj_arr);
411 | PyArray_XDECREF(obj_arr);
412 | return NULL;
413 | }
414 |
415 | static PyObject *np_join(PyObject *self, PyObject *args, PyObject *kws) {
416 | PyObject *arg1 = NULL, *arg2 = NULL, *query = NULL, *seq = NULL, **src;
417 | PyArrayObject *arr = NULL, *iarr = NULL, *oarr = NULL, *xarr = NULL;
418 | int nthread = -1, re = -1;
419 |
420 | static char *kwlist[] = {"walk", "key", "query", "nthread", "return_idx", NULL};
421 | if (!(PyArg_ParseTupleAndKeywords(args, kws, "OOO|ip", kwlist, &arg1, &arg2, &query, &nthread, &re))) {
422 | PyErr_SetString(PyExc_TypeError, "input parsing error.");
423 | return NULL;
424 | }
425 |
426 | /* handle walks (numpy array) */
427 | arr = (PyArrayObject *) PyArray_FROM_OTF(arg1, NPY_INT, NPY_ARRAY_IN_ARRAY);
428 | if (!arr) return NULL;
429 | npy_intp *arr_dims = PyArray_DIMS(arr);
430 | int stride;
431 | if (PyArray_NDIM(arr) > 2) {
432 | stride = (int) (arr_dims[1] * arr_dims[2]);
433 | } else {
434 | stride = (int) arr_dims[1];
435 | }
436 | int *Carr = (int *) PyArray_DATA(arr);
437 |
438 | /* handle keys (a list of tuple) */
439 | seq = PySequence_Fast(arg2, "argument must be iterable");
440 | if (!seq) return NULL;
441 | if (PySequence_Fast_GET_SIZE(seq) != arr_dims[0]) {
442 | PyErr_SetString(PyExc_TypeError, "dims do not match between walks and keys.");
443 | return NULL;
444 | }
445 |
446 | /* handle queries (numpy array/sequence) */
447 | iarr = (PyArrayObject *) PyArray_FROM_OTF(query, NPY_INT, NPY_ARRAY_IN_ARRAY | NPY_ARRAY_FORCECAST);
448 | if (!iarr) return NULL;
449 | npy_intp *iarr_dims = PyArray_DIMS(iarr);
450 | int *Ciarr = (int *) PyArray_DATA(iarr);
451 |
452 | if (DEBUG) {
453 | printf("Dims of query: %d, %d\n", (int) iarr_dims[0], (int) iarr_dims[1]);
454 | }
455 |
456 | /* initialize the hashtable */
457 | dict_item *items = NULL;
458 | int idx = 1;
459 | src = PySequence_Fast_ITEMS(seq);
460 |
461 | /* build two level hash table: 1) reindex main keys 2) hash unique node idx associated with each key */
462 | for (int i = 0; i < arr_dims[0]; i++) {
463 | /* make initial element */
464 | dict_item *k = malloc(sizeof(*k));
465 | // walk starts from each node (main key)
466 | k->key = Carr[i * stride];
467 | k->sub = NULL;
468 | k->val = i;
469 | HASH_ADD_INT(items, key, k);
470 |
471 | PyObject *item = PySequence_Fast(src[i], "argument must be iterable");
472 | int item_size;
473 | if (!PyArray_CheckExact(item)) {
474 | item_size = PySequence_Fast_GET_SIZE(item);
475 |
476 | for (int j = 0; j < item_size; j++) {
477 | /* add a sub hash table off this element */
478 | dict_item *w = malloc(sizeof(*w));
479 | w->key = (int) PyLong_AsLong(PySequence_Fast_GET_ITEM(item, j));
480 | w->sub = NULL;
481 | w->val = idx;
482 | HASH_ADD_INT(k->sub, key, w);
483 | idx++;
484 | }
485 | } else {
486 | item_size = PyArray_Size(item);
487 |
488 | for (int j = 0; j < item_size; j++) {
489 | /* add a sub hash table off this element */
490 | dict_item *w = malloc(sizeof(*w));
491 | w->key = (*(int *) PyArray_GETPTR1((PyArrayObject *) item, j));
492 | w->sub = NULL;
493 | w->val = idx;
494 | HASH_ADD_INT(k->sub, key, w);
495 | idx++;
496 | }
497 | }
498 | // must add, to avoid memory leakage
499 | Py_DECREF(item);
500 | }
501 |
502 | /* allocate a new return numpy array */
503 | npy_intp odims[2] = {2, iarr_dims[0] * 2 * stride};
504 | oarr = (PyArrayObject *) PyArray_SimpleNew(2, odims, NPY_INT);
505 | if (oarr == NULL) goto fail;
506 | int *Coarr = (int *) PyArray_DATA(oarr);
507 |
508 | if (DEBUG) {
509 | printf("Dims of output: %d, %d\n", (int) odims[0], (int) odims[1]);
510 | }
511 |
512 | xarr = (PyArrayObject *) PyArray_SimpleNew(2, iarr_dims, NPY_INT);
513 | if (xarr == NULL) goto fail;
514 | int *Cxarr = (int *) PyArray_DATA(xarr);
515 |
516 | if (nthread > 0) {
517 | omp_set_num_threads(nthread);
518 | }
519 |
520 | #pragma omp parallel for
521 | for (int x = 0; x < iarr_dims[0]; x++) {
522 | int qid = 2 * x;
523 | int key1 = Ciarr[qid], key2 = Ciarr[qid + 1];
524 | Cxarr[qid] = find_key_item(items, key1), Cxarr[qid+1] = find_key_item(items, key2);
525 | for (int y = 0; y < 2 * stride; y += 2) {
526 | Coarr[qid * stride + y] = find_idx(items, key1, Carr[Cxarr[qid] * stride + y / 2]);
527 | Coarr[qid * stride + y + 1] = find_idx(items, key2, Carr[Cxarr[qid] * stride + y / 2]);
528 | Coarr[odims[1] + qid * stride + y] = find_idx(items, key1, Carr[Cxarr[qid+1] * stride + y / 2]);
529 | Coarr[odims[1] + qid * stride + y + 1] = find_idx(items, key2, Carr[Cxarr[qid+1] * stride + y / 2]);
530 | }
531 | }
532 |
533 | Py_DECREF(arr);
534 | Py_DECREF(iarr);
535 | Py_DECREF(seq);
536 | delete_all(items);
537 | if (re>0){
538 | return Py_BuildValue("[N,N]", PyArray_Return(oarr), PyArray_Return(xarr));
539 | }else{
540 | return PyArray_Return(oarr);
541 | }
542 |
543 | fail:
544 | Py_XDECREF(arr);
545 | Py_XDECREF(iarr);
546 | Py_XDECREF(seq);
547 | delete_all(items);
548 | PyArray_DiscardWritebackIfCopy(oarr);
549 | PyArray_XDECREF(oarr);
550 | PyArray_DiscardWritebackIfCopy(xarr);
551 | PyArray_XDECREF(xarr);
552 | return NULL;
553 | }
554 |
555 | static PyMethodDef GAccMethods[] = {
556 | {"add", adds, METH_VARARGS, "Add ops."},
557 | {"run", exe, METH_VARARGS, "Execute a shell command."},
558 | {"sjoin", (PyCFunction) np_join, METH_VARARGS | METH_KEYWORDS,
559 | "RPE (subgraph) join op with a list of pairs (numpy, openmp)."},
560 | {"run_walk", (PyCFunction) np_walk, METH_VARARGS | METH_KEYWORDS,
561 | "Random walks with RPE encoding (numpy, openmp)."},
562 | {"run_sample", (PyCFunction) np_sample, METH_VARARGS | METH_KEYWORDS,
563 | "Random sampling (numpy, openmp)."},
564 | {NULL, NULL, 0, NULL}
565 | };
566 |
567 | static char gacc_doc[] = "C extension for SUREL framework.";
568 |
569 | static struct PyModuleDef gaccmodule = {
570 | PyModuleDef_HEAD_INIT,
571 | "surel_gacc", /* name of module */
572 | gacc_doc, /* module documentation, may be NULL */
573 | -1, /* size of per-interpreter state of the module,
574 | or -1 if the module keeps state in global variables. */
575 | GAccMethods
576 | };
577 |
578 | PyMODINIT_FUNC PyInit_surel_gacc(void) {
579 | import_array();
580 | return PyModule_Create(&gaccmodule);
581 | }
582 |
--------------------------------------------------------------------------------
/subg_acc/setup.py:
--------------------------------------------------------------------------------
"""Build script for the ``surel_gacc`` C extension (OpenMP-accelerated graph ops).

Build in place with ``python setup.py build_ext --inplace``.
"""
try:
    # distutils was deprecated by PEP 632 and removed in Python 3.12;
    # setuptools provides drop-in ``setup`` / ``Extension``.
    from setuptools import setup, Extension
except ImportError:  # legacy environments without setuptools
    from distutils.core import setup, Extension
import numpy

module1 = Extension(
    'surel_gacc',
    sources=['graph_acc.c'],
    extra_compile_args=['-fopenmp'],   # graph_acc.c uses OpenMP pragmas/omp_* calls
    extra_link_args=['-lgomp'],        # GNU OpenMP runtime
    include_dirs=[numpy.get_include()],  # NumPy C-API headers
)

setup(
    name='SUREL_GAcc',
    version='1.1',
    description='This is a package for accelerated graph operations in SUREL framework.',
    ext_modules=[module1],
)
14 |
--------------------------------------------------------------------------------
/subg_acc/uthash.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (c) 2003-2022, Troy D. Hanson https://troydhanson.github.io/uthash/
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 |
11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
12 | IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
13 | TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
14 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
15 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
16 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
17 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
18 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
19 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
20 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
21 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22 | */
23 |
24 | #ifndef UTHASH_H
25 | #define UTHASH_H
26 |
#define UTHASH_VERSION 2.3.0

/* Standard headers the macros below rely on (header names restored; the
 * directives had lost them, leaving `#include` with no file name). */
#include <string.h>   /* memcmp, memset, strlen */
#include <stddef.h>   /* ptrdiff_t */
#include <stdlib.h>   /* exit */

#if defined(HASH_DEFINE_OWN_STDINT) && HASH_DEFINE_OWN_STDINT
/* This codepath is provided for backward compatibility, but I plan to remove it. */
#warning "HASH_DEFINE_OWN_STDINT is deprecated; please use HASH_NO_STDINT instead"
typedef unsigned int uint32_t;
typedef unsigned char uint8_t;
#elif defined(HASH_NO_STDINT) && HASH_NO_STDINT
#else
#include <stdint.h>   /* uint8_t, uint32_t */
#endif
42 |
/* Portability shim: DECLTYPE(x) expands to a cast to the type of x, so the
 * macros below can assign untyped (void * or char *) pointers back to the
 * caller's element type. Compilers without decltype/__typeof get NO_DECLTYPE. */
43 | /* These macros use decltype or the earlier __typeof GNU extension.
44 | As decltype is only available in newer compilers (VS2010 or gcc 4.3+
45 | when compiling c++ source) this code uses whatever method is needed
46 | or, for VS2008 where neither is available, uses casting workarounds. */
47 | #if !defined(DECLTYPE) && !defined(NO_DECLTYPE)
48 | #if defined(_MSC_VER) /* MS compiler */
49 | #if _MSC_VER >= 1600 && defined(__cplusplus) /* VS2010 or newer in C++ mode */
50 | #define DECLTYPE(x) (decltype(x))
51 | #else /* VS2008 or older (or VS2010 in C mode) */
52 | #define NO_DECLTYPE
53 | #endif
54 | #elif defined(__MCST__) /* Elbrus C Compiler */
55 | #define DECLTYPE(x) (__typeof(x))
56 | #elif defined(__BORLANDC__) || defined(__ICCARM__) || defined(__LCC__) || defined(__WATCOMC__)
57 | #define NO_DECLTYPE
58 | #else /* GNU, Sun and other compilers */
59 | #define DECLTYPE(x) (__typeof(x))
60 | #endif
61 | #endif
62 |
/* DECLTYPE_ASSIGN(dst, src): store pointer src into dst. Without DECLTYPE it
 * writes through a char ** alias of &dst instead of casting. */
63 | #ifdef NO_DECLTYPE
64 | #define DECLTYPE(x)
65 | #define DECLTYPE_ASSIGN(dst,src) \
66 | do { \
67 | char **_da_dst = (char**)(&(dst)); \
68 | *_da_dst = (char*)(src); \
69 | } while (0)
70 | #else
71 | #define DECLTYPE_ASSIGN(dst,src) \
72 | do { \
73 | (dst) = DECLTYPE(dst)(src); \
74 | } while (0)
75 | #endif
76 |
/* Overridable hooks. Each default is defined only if the including code has not
 * already defined the macro: allocation/free, zeroing, strlen, the hash function
 * (default HASH_JEN, defined elsewhere in this header), key comparison, and the
 * bucket-expansion notification hooks (no-ops by default). */
77 | #ifndef uthash_malloc
78 | #define uthash_malloc(sz) malloc(sz) /* malloc fcn */
79 | #endif
80 | #ifndef uthash_free
81 | #define uthash_free(ptr,sz) free(ptr) /* free fcn */
82 | #endif
83 | #ifndef uthash_bzero
84 | #define uthash_bzero(a,n) memset(a,'\0',n)
85 | #endif
86 | #ifndef uthash_strlen
87 | #define uthash_strlen(s) strlen(s)
88 | #endif
89 |
90 | #ifndef HASH_FUNCTION
91 | #define HASH_FUNCTION(keyptr,keylen,hashv) HASH_JEN(keyptr, keylen, hashv)
92 | #endif
93 |
94 | #ifndef HASH_KEYCMP
95 | #define HASH_KEYCMP(a,b,n) memcmp(a,b,n)
96 | #endif
97 |
98 | #ifndef uthash_noexpand_fyi
99 | #define uthash_noexpand_fyi(tbl) /* can be defined to log noexpand */
100 | #endif
101 | #ifndef uthash_expand_fyi
102 | #define uthash_expand_fyi(tbl) /* can be defined to log expands */
103 | #endif
104 |
/* Out-of-memory policy. Unless HASH_NONFATAL_OOM is set nonzero, an allocation
 * failure calls uthash_fatal (exit(-1) unless the includer overrides it) and
 * IF_HASH_NONFATAL_OOM(x) discards its argument. With HASH_NONFATAL_OOM set,
 * HASH_RECORD_OOM only sets the caller's flag, IF_HASH_NONFATAL_OOM(x) keeps x,
 * and uthash_nonfatal_oom(obj) is the notification hook (a no-op by default). */
105 | #ifndef HASH_NONFATAL_OOM
106 | #define HASH_NONFATAL_OOM 0
107 | #endif
108 |
109 | #if HASH_NONFATAL_OOM
110 | /* malloc failures can be recovered from */
111 |
112 | #ifndef uthash_nonfatal_oom
113 | #define uthash_nonfatal_oom(obj) do {} while (0) /* non-fatal OOM error */
114 | #endif
115 |
116 | #define HASH_RECORD_OOM(oomed) do { (oomed) = 1; } while (0)
117 | #define IF_HASH_NONFATAL_OOM(x) x
118 |
119 | #else
120 | /* malloc failures result in lost memory, hash tables are unusable */
121 |
122 | #ifndef uthash_fatal
123 | #define uthash_fatal(msg) exit(-1) /* fatal OOM error */
124 | #endif
125 |
126 | #define HASH_RECORD_OOM(oomed) uthash_fatal("out of memory")
127 | #define IF_HASH_NONFATAL_OOM(x)
128 |
129 | #endif
130 |
/* Table sizing constants, and conversions between an element pointer and its
 * embedded UT_hash_handle. tbl->hho is the byte offset of the handle inside the
 * element; HASH_MAKE_TABLE computes it from the first element added. */
131 | /* initial number of buckets */
132 | #define HASH_INITIAL_NUM_BUCKETS 32U /* initial number of buckets */
133 | #define HASH_INITIAL_NUM_BUCKETS_LOG2 5U /* lg2 of initial number of buckets */
134 | #define HASH_BKT_CAPACITY_THRESH 10U /* expand when bucket count reaches */
135 |
136 | /* calculate the element whose hash handle address is hhp */
137 | #define ELMT_FROM_HH(tbl,hhp) ((void*)(((char*)(hhp)) - ((tbl)->hho)))
138 | /* calculate the hash handle from element address elp */
139 | #define HH_FROM_ELMT(tbl,elp) ((UT_hash_handle*)(void*)(((char*)(elp)) + ((tbl)->hho)))
140 |
/* HASH_ROLLBACK_BKT: increment the bucket count for itemptrhh's hash and clear
 * its bucket-chain links. Used on the non-fatal OOM path of HASH_ADD_TO_TABLE,
 * right before HASH_DELETE_HH. */
141 | #define HASH_ROLLBACK_BKT(hh, head, itemptrhh) \
142 | do { \
143 | struct UT_hash_handle *_hd_hh_item = (itemptrhh); \
144 | unsigned _hd_bkt; \
145 | HASH_TO_BKT(_hd_hh_item->hashv, (head)->hh.tbl->num_buckets, _hd_bkt); \
146 | (head)->hh.tbl->buckets[_hd_bkt].count++; \
147 | _hd_hh_item->hh_next = NULL; \
148 | _hd_hh_item->hh_prev = NULL; \
149 | } while (0)
150 |
/* HASH_VALUE: hash keylen bytes at keyptr into hashv via HASH_FUNCTION. */
151 | #define HASH_VALUE(keyptr,keylen,hashv) \
152 | do { \
153 | HASH_FUNCTION(keyptr, keylen, hashv); \
154 | } while (0)
155 |
/* HASH_FIND_BYHASHVALUE: look up a key whose hash is already known. out is NULL
 * when head is NULL or the bloom filter rules the hash out; otherwise the
 * bucket is searched with HASH_FIND_IN_BKT (defined elsewhere in this header). */
156 | #define HASH_FIND_BYHASHVALUE(hh,head,keyptr,keylen,hashval,out) \
157 | do { \
158 | (out) = NULL; \
159 | if (head) { \
160 | unsigned _hf_bkt; \
161 | HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _hf_bkt); \
162 | if (HASH_BLOOM_TEST((head)->hh.tbl, hashval) != 0) { \
163 | HASH_FIND_IN_BKT((head)->hh.tbl, hh, (head)->hh.tbl->buckets[ _hf_bkt ], keyptr, keylen, hashval, out); \
164 | } \
165 | } \
166 | } while (0)
167 |
/* HASH_FIND: hash the key, then look it up; out stays NULL for an empty table. */
168 | #define HASH_FIND(hh,head,keyptr,keylen,out) \
169 | do { \
170 | (out) = NULL; \
171 | if (head) { \
172 | unsigned _hf_hashv; \
173 | HASH_VALUE(keyptr, keylen, _hf_hashv); \
174 | HASH_FIND_BYHASHVALUE(hh, head, keyptr, keylen, _hf_hashv, out); \
175 | } \
176 | } while (0)
177 |
/* Optional per-table bloom filter, enabled by defining HASH_BLOOM (the filter
 * then has 1 << HASH_BLOOM bits). Bits are only ever set, so a clear bit lets
 * HASH_FIND_BYHASHVALUE skip the bucket search. When HASH_BLOOM is undefined the
 * helpers are no-ops and HASH_BLOOM_TEST is always true. */
178 | #ifdef HASH_BLOOM
179 | #define HASH_BLOOM_BITLEN (1UL << HASH_BLOOM)
180 | #define HASH_BLOOM_BYTELEN (HASH_BLOOM_BITLEN/8UL) + (((HASH_BLOOM_BITLEN%8UL)!=0UL) ? 1UL : 0UL)
/* NOTE(review): HASH_BLOOM_BYTELEN has no outer parentheses; use it only as a
 * standalone expression (it is passed as a plain argument below). */
181 | #define HASH_BLOOM_MAKE(tbl,oomed) \
182 | do { \
183 | (tbl)->bloom_nbits = HASH_BLOOM; \
184 | (tbl)->bloom_bv = (uint8_t*)uthash_malloc(HASH_BLOOM_BYTELEN); \
185 | if (!(tbl)->bloom_bv) { \
186 | HASH_RECORD_OOM(oomed); \
187 | } else { \
188 | uthash_bzero((tbl)->bloom_bv, HASH_BLOOM_BYTELEN); \
189 | (tbl)->bloom_sig = HASH_BLOOM_SIGNATURE; \
190 | } \
191 | } while (0)
192 |
193 | #define HASH_BLOOM_FREE(tbl) \
194 | do { \
195 | uthash_free((tbl)->bloom_bv, HASH_BLOOM_BYTELEN); \
196 | } while (0)
197 |
198 | #define HASH_BLOOM_BITSET(bv,idx) (bv[(idx)/8U] |= (1U << ((idx)%8U)))
199 | #define HASH_BLOOM_BITTEST(bv,idx) (bv[(idx)/8U] & (1U << ((idx)%8U)))
200 |
201 | #define HASH_BLOOM_ADD(tbl,hashv) \
202 | HASH_BLOOM_BITSET((tbl)->bloom_bv, ((hashv) & (uint32_t)((1UL << (tbl)->bloom_nbits) - 1U)))
203 |
204 | #define HASH_BLOOM_TEST(tbl,hashv) \
205 | HASH_BLOOM_BITTEST((tbl)->bloom_bv, ((hashv) & (uint32_t)((1UL << (tbl)->bloom_nbits) - 1U)))
206 |
207 | #else
208 | #define HASH_BLOOM_MAKE(tbl,oomed)
209 | #define HASH_BLOOM_FREE(tbl)
210 | #define HASH_BLOOM_ADD(tbl,hashv)
211 | #define HASH_BLOOM_TEST(tbl,hashv) (1)
212 | #define HASH_BLOOM_BYTELEN 0U
213 | #endif
214 |
/* HASH_MAKE_TABLE: create the shared UT_hash_table for the first element, head.
 * It allocates a zeroed table with tail = head's handle, the initial bucket count,
 * hho = offset of the handle within the element, zeroed buckets, the signature and
 * the optional bloom filter. Allocation failures go through HASH_RECORD_OOM; if
 * the bucket array cannot be allocated, the table block is freed again, and in
 * non-fatal mode a bloom failure also frees the buckets and table. */
215 | #define HASH_MAKE_TABLE(hh,head,oomed) \
216 | do { \
217 | (head)->hh.tbl = (UT_hash_table*)uthash_malloc(sizeof(UT_hash_table)); \
218 | if (!(head)->hh.tbl) { \
219 | HASH_RECORD_OOM(oomed); \
220 | } else { \
221 | uthash_bzero((head)->hh.tbl, sizeof(UT_hash_table)); \
222 | (head)->hh.tbl->tail = &((head)->hh); \
223 | (head)->hh.tbl->num_buckets = HASH_INITIAL_NUM_BUCKETS; \
224 | (head)->hh.tbl->log2_num_buckets = HASH_INITIAL_NUM_BUCKETS_LOG2; \
225 | (head)->hh.tbl->hho = (char*)(&(head)->hh) - (char*)(head); \
226 | (head)->hh.tbl->buckets = (UT_hash_bucket*)uthash_malloc( \
227 | HASH_INITIAL_NUM_BUCKETS * sizeof(struct UT_hash_bucket)); \
228 | (head)->hh.tbl->signature = HASH_SIGNATURE; \
229 | if (!(head)->hh.tbl->buckets) { \
230 | HASH_RECORD_OOM(oomed); \
231 | uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \
232 | } else { \
233 | uthash_bzero((head)->hh.tbl->buckets, \
234 | HASH_INITIAL_NUM_BUCKETS * sizeof(struct UT_hash_bucket)); \
235 | HASH_BLOOM_MAKE((head)->hh.tbl, oomed); \
236 | IF_HASH_NONFATAL_OOM( \
237 | if (oomed) { \
238 | uthash_free((head)->hh.tbl->buckets, \
239 | HASH_INITIAL_NUM_BUCKETS*sizeof(struct UT_hash_bucket)); \
240 | uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \
241 | } \
242 | ) \
243 | } \
244 | } \
245 | } while (0)
246 |
/* HASH_REPLACE family: if an element with the same key is present, unlink it
 * with HASH_DELETE (which does not free it) and hand it back in `replaced`;
 * otherwise `replaced` is NULL. `add` is then inserted. _INORDER variants keep
 * cmpfcn order; _BYHASHVALUE variants take a precomputed hash. */
247 | #define HASH_REPLACE_BYHASHVALUE_INORDER(hh,head,fieldname,keylen_in,hashval,add,replaced,cmpfcn) \
248 | do { \
249 | (replaced) = NULL; \
250 | HASH_FIND_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, replaced); \
251 | if (replaced) { \
252 | HASH_DELETE(hh, head, replaced); \
253 | } \
254 | HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, &((add)->fieldname), keylen_in, hashval, add, cmpfcn); \
255 | } while (0)
256 |
257 | #define HASH_REPLACE_BYHASHVALUE(hh,head,fieldname,keylen_in,hashval,add,replaced) \
258 | do { \
259 | (replaced) = NULL; \
260 | HASH_FIND_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, replaced); \
261 | if (replaced) { \
262 | HASH_DELETE(hh, head, replaced); \
263 | } \
264 | HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, add); \
265 | } while (0)
266 |
267 | #define HASH_REPLACE(hh,head,fieldname,keylen_in,add,replaced) \
268 | do { \
269 | unsigned _hr_hashv; \
270 | HASH_VALUE(&((add)->fieldname), keylen_in, _hr_hashv); \
271 | HASH_REPLACE_BYHASHVALUE(hh, head, fieldname, keylen_in, _hr_hashv, add, replaced); \
272 | } while (0)
273 |
274 | #define HASH_REPLACE_INORDER(hh,head,fieldname,keylen_in,add,replaced,cmpfcn) \
275 | do { \
276 | unsigned _hr_hashv; \
277 | HASH_VALUE(&((add)->fieldname), keylen_in, _hr_hashv); \
278 | HASH_REPLACE_BYHASHVALUE_INORDER(hh, head, fieldname, keylen_in, _hr_hashv, add, replaced, cmpfcn); \
279 | } while (0)
280 |
/* HASH_APPEND_LIST: link `add` at the tail of the application-order list. */
281 | #define HASH_APPEND_LIST(hh, head, add) \
282 | do { \
283 | (add)->hh.next = NULL; \
284 | (add)->hh.prev = ELMT_FROM_HH((head)->hh.tbl, (head)->hh.tbl->tail); \
285 | (head)->hh.tbl->tail->next = (add); \
286 | (head)->hh.tbl->tail = &((add)->hh); \
287 | } while (0)
288 |
/* HASH_AKBI_INNER_LOOP: starting from the caller's _hs_iter, walk the
 * application-order list until cmpfcn(element, add) > 0. _hs_iter is left on
 * that element, or NULL if no element compares greater. */
289 | #define HASH_AKBI_INNER_LOOP(hh,head,add,cmpfcn) \
290 | do { \
291 | do { \
292 | if (cmpfcn(DECLTYPE(head)(_hs_iter), add) > 0) { \
293 | break; \
294 | } \
295 | } while ((_hs_iter = HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->next)); \
296 | } while (0)
297 |
/* Without DECLTYPE, head is temporarily pointed at each element so cmpfcn gets
 * a correctly typed pointer, and restored before leaving the loop body. */
298 | #ifdef NO_DECLTYPE
299 | #undef HASH_AKBI_INNER_LOOP
300 | #define HASH_AKBI_INNER_LOOP(hh,head,add,cmpfcn) \
301 | do { \
302 | char *_hs_saved_head = (char*)(head); \
303 | do { \
304 | DECLTYPE_ASSIGN(head, _hs_iter); \
305 | if (cmpfcn(head, add) > 0) { \
306 | DECLTYPE_ASSIGN(head, _hs_saved_head); \
307 | break; \
308 | } \
309 | DECLTYPE_ASSIGN(head, _hs_saved_head); \
310 | } while ((_hs_iter = HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->next)); \
311 | } while (0)
312 | #endif
313 |
/* HASH_ADD_TO_TABLE: count the new element, chain its handle into the bucket
 * selected by HASH_TO_BKT, then update the bloom filter and key emitter.
 * Non-fatal OOM variant: if table creation already failed (oomed set), add is
 * only marked table-less and reported; if the bucket insert fails, it is
 * rolled back (HASH_ROLLBACK_BKT + HASH_DELETE_HH) before being reported via
 * uthash_nonfatal_oom(add). */
314 | #if HASH_NONFATAL_OOM
315 |
316 | #define HASH_ADD_TO_TABLE(hh,head,keyptr,keylen_in,hashval,add,oomed) \
317 | do { \
318 | if (!(oomed)) { \
319 | unsigned _ha_bkt; \
320 | (head)->hh.tbl->num_items++; \
321 | HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _ha_bkt); \
322 | HASH_ADD_TO_BKT((head)->hh.tbl->buckets[_ha_bkt], hh, &(add)->hh, oomed); \
323 | if (oomed) { \
324 | HASH_ROLLBACK_BKT(hh, head, &(add)->hh); \
325 | HASH_DELETE_HH(hh, head, &(add)->hh); \
326 | (add)->hh.tbl = NULL; \
327 | uthash_nonfatal_oom(add); \
328 | } else { \
329 | HASH_BLOOM_ADD((head)->hh.tbl, hashval); \
330 | HASH_EMIT_KEY(hh, head, keyptr, keylen_in); \
331 | } \
332 | } else { \
333 | (add)->hh.tbl = NULL; \
334 | uthash_nonfatal_oom(add); \
335 | } \
336 | } while (0)
337 |
338 | #else
339 |
340 | #define HASH_ADD_TO_TABLE(hh,head,keyptr,keylen_in,hashval,add,oomed) \
341 | do { \
342 | unsigned _ha_bkt; \
343 | (head)->hh.tbl->num_items++; \
344 | HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _ha_bkt); \
345 | HASH_ADD_TO_BKT((head)->hh.tbl->buckets[_ha_bkt], hh, &(add)->hh, oomed); \
346 | HASH_BLOOM_ADD((head)->hh.tbl, hashval); \
347 | HASH_EMIT_KEY(hh, head, keyptr, keylen_in); \
348 | } while (0)
349 |
350 | #endif
351 |
352 |
/* Ordered insertion: `add` is placed before the first element for which
 * cmpfcn(element, add) > 0, or appended at the tail if there is none, so a list
 * kept with these macros stays sorted by cmpfcn. The first element added creates
 * the table. Only the pointer keyptr is stored, not a copy of the key. */
353 | #define HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh,head,keyptr,keylen_in,hashval,add,cmpfcn) \
354 | do { \
355 | IF_HASH_NONFATAL_OOM( int _ha_oomed = 0; ) \
356 | (add)->hh.hashv = (hashval); \
357 | (add)->hh.key = (char*) (keyptr); \
358 | (add)->hh.keylen = (unsigned) (keylen_in); \
359 | if (!(head)) { \
360 | (add)->hh.next = NULL; \
361 | (add)->hh.prev = NULL; \
362 | HASH_MAKE_TABLE(hh, add, _ha_oomed); \
363 | IF_HASH_NONFATAL_OOM( if (!_ha_oomed) { ) \
364 | (head) = (add); \
365 | IF_HASH_NONFATAL_OOM( } ) \
366 | } else { \
367 | void *_hs_iter = (head); \
368 | (add)->hh.tbl = (head)->hh.tbl; \
369 | HASH_AKBI_INNER_LOOP(hh, head, add, cmpfcn); \
370 | if (_hs_iter) { \
371 | (add)->hh.next = _hs_iter; \
372 | if (((add)->hh.prev = HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->prev)) { \
373 | HH_FROM_ELMT((head)->hh.tbl, (add)->hh.prev)->next = (add); \
374 | } else { \
375 | (head) = (add); \
376 | } \
377 | HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->prev = (add); \
378 | } else { \
379 | HASH_APPEND_LIST(hh, head, add); \
380 | } \
381 | } \
382 | HASH_ADD_TO_TABLE(hh, head, keyptr, keylen_in, hashval, add, _ha_oomed); \
383 | HASH_FSCK(hh, head, "HASH_ADD_KEYPTR_BYHASHVALUE_INORDER"); \
384 | } while (0)
385 |
/* Convenience forms: compute the hash (HASH_ADD_KEYPTR_INORDER) and/or take the
 * key from a field of the element (HASH_ADD_BYHASHVALUE_INORDER, HASH_ADD_INORDER). */
386 | #define HASH_ADD_KEYPTR_INORDER(hh,head,keyptr,keylen_in,add,cmpfcn) \
387 | do { \
388 | unsigned _hs_hashv; \
389 | HASH_VALUE(keyptr, keylen_in, _hs_hashv); \
390 | HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, keyptr, keylen_in, _hs_hashv, add, cmpfcn); \
391 | } while (0)
392 |
393 | #define HASH_ADD_BYHASHVALUE_INORDER(hh,head,fieldname,keylen_in,hashval,add,cmpfcn) \
394 | HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, &((add)->fieldname), keylen_in, hashval, add, cmpfcn)
395 |
396 | #define HASH_ADD_INORDER(hh,head,fieldname,keylen_in,add,cmpfcn) \
397 | HASH_ADD_KEYPTR_INORDER(hh, head, &((add)->fieldname), keylen_in, add, cmpfcn)
398 |
/* HASH_ADD_KEYPTR_BYHASHVALUE: append `add` to the application-order list and its
 * bucket, creating the table on first use. Only the pointer keyptr is stored
 * (hh.key), not a copy of the key, so the key bytes must stay valid and unchanged
 * while the element is in the table. Duplicate keys are not checked here. */
399 | #define HASH_ADD_KEYPTR_BYHASHVALUE(hh,head,keyptr,keylen_in,hashval,add) \
400 | do { \
401 | IF_HASH_NONFATAL_OOM( int _ha_oomed = 0; ) \
402 | (add)->hh.hashv = (hashval); \
403 | (add)->hh.key = (const void*) (keyptr); \
404 | (add)->hh.keylen = (unsigned) (keylen_in); \
405 | if (!(head)) { \
406 | (add)->hh.next = NULL; \
407 | (add)->hh.prev = NULL; \
408 | HASH_MAKE_TABLE(hh, add, _ha_oomed); \
409 | IF_HASH_NONFATAL_OOM( if (!_ha_oomed) { ) \
410 | (head) = (add); \
411 | IF_HASH_NONFATAL_OOM( } ) \
412 | } else { \
413 | (add)->hh.tbl = (head)->hh.tbl; \
414 | HASH_APPEND_LIST(hh, head, add); \
415 | } \
416 | HASH_ADD_TO_TABLE(hh, head, keyptr, keylen_in, hashval, add, _ha_oomed); \
417 | HASH_FSCK(hh, head, "HASH_ADD_KEYPTR_BYHASHVALUE"); \
418 | } while (0)
419 |
/* HASH_ADD_KEYPTR computes the hash first; HASH_ADD_BYHASHVALUE and HASH_ADD use
 * a field of the element itself as the key. */
420 | #define HASH_ADD_KEYPTR(hh,head,keyptr,keylen_in,add) \
421 | do { \
422 | unsigned _ha_hashv; \
423 | HASH_VALUE(keyptr, keylen_in, _ha_hashv); \
424 | HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, keyptr, keylen_in, _ha_hashv, add); \
425 | } while (0)
426 |
427 | #define HASH_ADD_BYHASHVALUE(hh,head,fieldname,keylen_in,hashval,add) \
428 | HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, add)
429 |
430 | #define HASH_ADD(hh,head,fieldname,keylen_in,add) \
431 | HASH_ADD_KEYPTR(hh, head, &((add)->fieldname), keylen_in, add)
432 |
/* HASH_TO_BKT: bucket index by masking, which assumes num_bkts is a power of two
 * (the initial count is 32). */
433 | #define HASH_TO_BKT(hashv,num_bkts,bkt) \
434 | do { \
435 | bkt = ((hashv) & ((num_bkts) - 1U)); \
436 | } while (0)
437 |
438 | /* delete "delptr" from the hash table.
439 | * "the usual" patch-up process for the app-order doubly-linked-list.
440 | * The use of _hd_hh_del below deserves special explanation.
441 | * These used to be expressed using (delptr) but that led to a bug
442 | * if someone used the same symbol for the head and deletee, like
443 | * HASH_DELETE(hh,users,users);
444 | * We want that to work, but by changing the head (users) below
445 | * we were forfeiting our ability to further refer to the deletee (users)
446 | * in the patch-up process. Solution: use scratch space to
447 | * copy the deletee pointer, then the latter references are via that
448 | * scratch pointer rather than through the repointed (users) symbol.
449 | */
450 | #define HASH_DELETE(hh,head,delptr) \
451 | HASH_DELETE_HH(hh, head, &(delptr)->hh)
452 |
/* HASH_DELETE_HH: unlink a handle from the application-order list and from its
 * bucket. Removing the only element frees the table (buckets, bloom filter) and
 * sets head to NULL. The element's own memory is never freed here; the caller
 * owns it. */
453 | #define HASH_DELETE_HH(hh,head,delptrhh) \
454 | do { \
455 | const struct UT_hash_handle *_hd_hh_del = (delptrhh); \
456 | if ((_hd_hh_del->prev == NULL) && (_hd_hh_del->next == NULL)) { \
457 | HASH_BLOOM_FREE((head)->hh.tbl); \
458 | uthash_free((head)->hh.tbl->buckets, \
459 | (head)->hh.tbl->num_buckets * sizeof(struct UT_hash_bucket)); \
460 | uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \
461 | (head) = NULL; \
462 | } else { \
463 | unsigned _hd_bkt; \
464 | if (_hd_hh_del == (head)->hh.tbl->tail) { \
465 | (head)->hh.tbl->tail = HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->prev); \
466 | } \
467 | if (_hd_hh_del->prev != NULL) { \
468 | HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->prev)->next = _hd_hh_del->next; \
469 | } else { \
470 | DECLTYPE_ASSIGN(head, _hd_hh_del->next); \
471 | } \
472 | if (_hd_hh_del->next != NULL) { \
473 | HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->next)->prev = _hd_hh_del->prev; \
474 | } \
475 | HASH_TO_BKT(_hd_hh_del->hashv, (head)->hh.tbl->num_buckets, _hd_bkt); \
476 | HASH_DEL_IN_BKT((head)->hh.tbl->buckets[_hd_bkt], _hd_hh_del); \
477 | (head)->hh.tbl->num_items--; \
478 | } \
479 | HASH_FSCK(hh, head, "HASH_DELETE_HH"); \
480 | } while (0)
481 |
482 | /* convenience forms of HASH_FIND/HASH_ADD/HASH_DEL */
/* All of these assume the handle field is named hh. _STR forms key on strlen
 * bytes of the string; _INT and _PTR forms key on sizeof(int) / sizeof(void *)
 * bytes. HASH_FIND_INT and HASH_FIND_PTR take a POINTER to the key value, because
 * it is passed on as HASH_FIND's keyptr. */
483 | #define HASH_FIND_STR(head,findstr,out) \
484 | do { \
485 | unsigned _uthash_hfstr_keylen = (unsigned)uthash_strlen(findstr); \
486 | HASH_FIND(hh, head, findstr, _uthash_hfstr_keylen, out); \
487 | } while (0)
488 | #define HASH_ADD_STR(head,strfield,add) \
489 | do { \
490 | unsigned _uthash_hastr_keylen = (unsigned)uthash_strlen((add)->strfield); \
491 | HASH_ADD(hh, head, strfield[0], _uthash_hastr_keylen, add); \
492 | } while (0)
493 | #define HASH_REPLACE_STR(head,strfield,add,replaced) \
494 | do { \
495 | unsigned _uthash_hrstr_keylen = (unsigned)uthash_strlen((add)->strfield); \
496 | HASH_REPLACE(hh, head, strfield[0], _uthash_hrstr_keylen, add, replaced); \
497 | } while (0)
498 | #define HASH_FIND_INT(head,findint,out) \
499 | HASH_FIND(hh,head,findint,sizeof(int),out)
500 | #define HASH_ADD_INT(head,intfield,add) \
501 | HASH_ADD(hh,head,intfield,sizeof(int),add)
502 | #define HASH_REPLACE_INT(head,intfield,add,replaced) \
503 | HASH_REPLACE(hh,head,intfield,sizeof(int),add,replaced)
504 | #define HASH_FIND_PTR(head,findptr,out) \
505 | HASH_FIND(hh,head,findptr,sizeof(void *),out)
506 | #define HASH_ADD_PTR(head,ptrfield,add) \
507 | HASH_ADD(hh,head,ptrfield,sizeof(void *),add)
508 | #define HASH_REPLACE_PTR(head,ptrfield,add,replaced) \
509 | HASH_REPLACE(hh,head,ptrfield,sizeof(void *),add,replaced)
510 | #define HASH_DEL(head,delptr) \
511 | HASH_DELETE(hh,head,delptr)
512 |
/* HASH_FSCK checks hash integrity on every add/delete when HASH_DEBUG is defined.
 * This is for uthash developer only; it compiles away if HASH_DEBUG isn't defined.
 * It verifies every bucket chain (hh_prev links and per-bucket counts), the total
 * item count, and the application-order list (prev links and length). The first
 * mismatch is reported through HASH_OOPS, which prints to stderr and exits.
 * (The <stdio.h> header name below had been stripped, leaving a bare #include.)
 */
#ifdef HASH_DEBUG
#include <stdio.h>   /* fprintf, stderr */
#define HASH_OOPS(...) do { fprintf(stderr, __VA_ARGS__); exit(-1); } while (0)
#define HASH_FSCK(hh,head,where)                                                  \
do {                                                                              \
  struct UT_hash_handle *_thh;                                                    \
  if (head) {                                                                     \
    unsigned _bkt_i;                                                              \
    unsigned _count = 0;                                                          \
    char *_prev;                                                                  \
    for (_bkt_i = 0; _bkt_i < (head)->hh.tbl->num_buckets; ++_bkt_i) {           \
      unsigned _bkt_count = 0;                                                    \
      _thh = (head)->hh.tbl->buckets[_bkt_i].hh_head;                             \
      _prev = NULL;                                                               \
      while (_thh) {                                                              \
        if (_prev != (char*)(_thh->hh_prev)) {                                    \
          HASH_OOPS("%s: invalid hh_prev %p, actual %p\n",                        \
              (where), (void*)_thh->hh_prev, (void*)_prev);                       \
        }                                                                         \
        _bkt_count++;                                                             \
        _prev = (char*)(_thh);                                                    \
        _thh = _thh->hh_next;                                                     \
      }                                                                           \
      _count += _bkt_count;                                                       \
      if ((head)->hh.tbl->buckets[_bkt_i].count != _bkt_count) {                  \
        HASH_OOPS("%s: invalid bucket count %u, actual %u\n",                     \
            (where), (head)->hh.tbl->buckets[_bkt_i].count, _bkt_count);          \
      }                                                                           \
    }                                                                             \
    if (_count != (head)->hh.tbl->num_items) {                                    \
      HASH_OOPS("%s: invalid hh item count %u, actual %u\n",                      \
          (where), (head)->hh.tbl->num_items, _count);                            \
    }                                                                             \
    _count = 0;                                                                   \
    _prev = NULL;                                                                 \
    _thh = &(head)->hh;                                                           \
    while (_thh) {                                                                \
      _count++;                                                                   \
      if (_prev != (char*)_thh->prev) {                                           \
        HASH_OOPS("%s: invalid prev %p, actual %p\n",                             \
            (where), (void*)_thh->prev, (void*)_prev);                            \
      }                                                                           \
      _prev = (char*)ELMT_FROM_HH((head)->hh.tbl, _thh);                          \
      _thh = (_thh->next ? HH_FROM_ELMT((head)->hh.tbl, _thh->next) : NULL);      \
    }                                                                             \
    if (_count != (head)->hh.tbl->num_items) {                                    \
      HASH_OOPS("%s: invalid app item count %u, actual %u\n",                     \
          (where), (head)->hh.tbl->num_items, _count);                            \
    }                                                                             \
  }                                                                               \
} while (0)
#else
#define HASH_FSCK(hh,head,where)
#endif
570 |
571 | /* When compiled with -DHASH_EMIT_KEYS, length-prefixed keys are emitted to
572 | * the descriptor to which this macro is defined for tuning the hash function.
573 | * The app can #include <unistd.h> to get the prototype for write(2). */
/* Each key is written as an `unsigned` length followed by the raw key bytes.
 * The write(2) return values are not checked, and `fieldlen` is expanded
 * twice, so the argument should be free of side effects. */
574 | #ifdef HASH_EMIT_KEYS
575 | #define HASH_EMIT_KEY(hh,head,keyptr,fieldlen) \
576 | do { \
577 | unsigned _klen = fieldlen; \
578 | write(HASH_EMIT_KEYS, &_klen, sizeof(_klen)); \
579 | write(HASH_EMIT_KEYS, keyptr, (unsigned long)fieldlen); \
580 | } while (0)
581 | #else
582 | #define HASH_EMIT_KEY(hh,head,keyptr,fieldlen)
583 | #endif
584 |
585 | /* The Bernstein hash function, used in Perl prior to v5.6. Note ((x<<5)+x) = x*33. */
/* Starting from 0, each key byte is folded in as hashv = hashv*33 + byte. */
586 | #define HASH_BER(key,keylen,hashv) \
587 | do { \
588 | unsigned _hb_keylen = (unsigned)keylen; \
589 | const unsigned char *_hb_key = (const unsigned char*)(key); \
590 | (hashv) = 0; \
591 | while (_hb_keylen-- != 0U) { \
592 | (hashv) = (((hashv) << 5) + (hashv)) + *_hb_key++; \
593 | } \
594 | } while (0)
595 |
596 |
597 | /* SAX/FNV/OAT/JEN hash functions are macro variants of those listed at
598 | * http://eternallyconfuzzled.com/tuts/algorithms/jsw_tut_hashing.aspx
599 | * (archive link: https://archive.is/Ivcan )
600 | */
/* SAX (shift-add-xor): starting from 0, each byte applies
 * hashv ^= (hashv << 5) + (hashv >> 2) + byte.
 * `keylen` is re-evaluated on every loop test and `hashv` is used
 * unparenthesized, so pass simple expressions for both. */
601 | #define HASH_SAX(key,keylen,hashv) \
602 | do { \
603 | unsigned _sx_i; \
604 | const unsigned char *_hs_key = (const unsigned char*)(key); \
605 | hashv = 0; \
606 | for (_sx_i=0; _sx_i < keylen; _sx_i++) { \
607 | hashv ^= (hashv << 5) + (hashv >> 2) + _hs_key[_sx_i]; \
608 | } \
609 | } while (0)
610 | /* FNV-1a variation */
/* Starts from the 32-bit FNV offset basis 2166136261 and, for each byte,
 * first XORs the byte in and then multiplies by the FNV prime 16777619
 * (xor-then-multiply is the FNV-1a ordering). */
611 | #define HASH_FNV(key,keylen,hashv) \
612 | do { \
613 | unsigned _fn_i; \
614 | const unsigned char *_hf_key = (const unsigned char*)(key); \
615 | (hashv) = 2166136261U; \
616 | for (_fn_i=0; _fn_i < keylen; _fn_i++) { \
617 | hashv = hashv ^ _hf_key[_fn_i]; \
618 | hashv = hashv * 16777619U; \
619 | } \
620 | } while (0)
621 |
/* OAT (one-at-a-time): starting from 0, each byte is added and then mixed
 * with hashv += hashv<<10; hashv ^= hashv>>6. A final shift/add/xor
 * sequence (<<3, >>11, <<15) spreads the remaining bits. */
622 | #define HASH_OAT(key,keylen,hashv) \
623 | do { \
624 | unsigned _ho_i; \
625 | const unsigned char *_ho_key=(const unsigned char*)(key); \
626 | hashv = 0; \
627 | for(_ho_i=0; _ho_i < keylen; _ho_i++) { \
628 | hashv += _ho_key[_ho_i]; \
629 | hashv += (hashv << 10); \
630 | hashv ^= (hashv >> 6); \
631 | } \
632 | hashv += (hashv << 3); \
633 | hashv ^= (hashv >> 11); \
634 | hashv += (hashv << 15); \
635 | } while (0)
636 |
/* HASH_JEN_MIX: mix three unsigned values in place using a fixed sequence
 * of subtract / xor-shift steps (the shift constants are those of Bob
 * Jenkins' lookup2 mix). Arguments are modified and used unparenthesized. */
637 | #define HASH_JEN_MIX(a,b,c) \
638 | do { \
639 | a -= b; a -= c; a ^= ( c >> 13 ); \
640 | b -= c; b -= a; b ^= ( a << 8 ); \
641 | c -= a; c -= b; c ^= ( b >> 13 ); \
642 | a -= b; a -= c; a ^= ( c >> 12 ); \
643 | b -= c; b -= a; b ^= ( a << 16 ); \
644 | c -= a; c -= b; c ^= ( b >> 5 ); \
645 | a -= b; a -= c; a ^= ( c >> 3 ); \
646 | b -= c; b -= a; b ^= ( a << 10 ); \
647 | c -= a; c -= b; c ^= ( b >> 15 ); \
648 | } while (0)
649 |
/* HASH_JEN: consume the key in 12-byte blocks, each assembled byte-wise with
 * key[0] as the least significant byte into (_hj_i, _hj_j, hashv), mixing
 * after every block; then add the key length and the remaining 0-11 bytes
 * (switch fall-through, highest byte first) and mix once more.
 * Initial state: hashv = 0xfeedbeef, _hj_i = _hj_j = 0x9e3779b9. */
650 | #define HASH_JEN(key,keylen,hashv) \
651 | do { \
652 | unsigned _hj_i,_hj_j,_hj_k; \
653 | unsigned const char *_hj_key=(unsigned const char*)(key); \
654 | hashv = 0xfeedbeefu; \
655 | _hj_i = _hj_j = 0x9e3779b9u; \
656 | _hj_k = (unsigned)(keylen); \
657 | while (_hj_k >= 12U) { \
658 | _hj_i += (_hj_key[0] + ( (unsigned)_hj_key[1] << 8 ) \
659 | + ( (unsigned)_hj_key[2] << 16 ) \
660 | + ( (unsigned)_hj_key[3] << 24 ) ); \
661 | _hj_j += (_hj_key[4] + ( (unsigned)_hj_key[5] << 8 ) \
662 | + ( (unsigned)_hj_key[6] << 16 ) \
663 | + ( (unsigned)_hj_key[7] << 24 ) ); \
664 | hashv += (_hj_key[8] + ( (unsigned)_hj_key[9] << 8 ) \
665 | + ( (unsigned)_hj_key[10] << 16 ) \
666 | + ( (unsigned)_hj_key[11] << 24 ) ); \
667 | \
668 | HASH_JEN_MIX(_hj_i, _hj_j, hashv); \
669 | \
670 | _hj_key += 12; \
671 | _hj_k -= 12U; \
672 | } \
673 | hashv += (unsigned)(keylen); \
674 | switch ( _hj_k ) { \
675 | case 11: hashv += ( (unsigned)_hj_key[10] << 24 ); /* FALLTHROUGH */ \
676 | case 10: hashv += ( (unsigned)_hj_key[9] << 16 ); /* FALLTHROUGH */ \
677 | case 9: hashv += ( (unsigned)_hj_key[8] << 8 ); /* FALLTHROUGH */ \
678 | case 8: _hj_j += ( (unsigned)_hj_key[7] << 24 ); /* FALLTHROUGH */ \
679 | case 7: _hj_j += ( (unsigned)_hj_key[6] << 16 ); /* FALLTHROUGH */ \
680 | case 6: _hj_j += ( (unsigned)_hj_key[5] << 8 ); /* FALLTHROUGH */ \
681 | case 5: _hj_j += _hj_key[4]; /* FALLTHROUGH */ \
682 | case 4: _hj_i += ( (unsigned)_hj_key[3] << 24 ); /* FALLTHROUGH */ \
683 | case 3: _hj_i += ( (unsigned)_hj_key[2] << 16 ); /* FALLTHROUGH */ \
684 | case 2: _hj_i += ( (unsigned)_hj_key[1] << 8 ); /* FALLTHROUGH */ \
685 | case 1: _hj_i += _hj_key[0]; /* FALLTHROUGH */ \
686 | default: ; \
687 | } \
688 | HASH_JEN_MIX(_hj_i, _hj_j, hashv); \
689 | } while (0)
690 |
691 | /* The Paul Hsieh hash function */
/* get16bits(d) reads the two bytes at d as a 16-bit value. With the
 * compilers/targets tested below it is a direct uint16_t load (host byte
 * order, no alignment handling); everywhere else it is assembled byte by
 * byte as d[0] + (d[1] << 8). */
692 | #undef get16bits
693 | #if (defined(__GNUC__) && defined(__i386__)) || defined(__WATCOMC__) \
694 | || defined(_MSC_VER) || defined (__BORLANDC__) || defined (__TURBOC__)
695 | #define get16bits(d) (*((const uint16_t *) (d)))
696 | #endif
697 |
698 | #if !defined (get16bits)
699 | #define get16bits(d) ((((uint32_t)(((const uint8_t *)(d))[1])) << 8) \
700 | +(uint32_t)(((const uint8_t *)(d))[0]) )
701 | #endif
/* HASH_SFH: starting from 0xcafebabe, consume the key 4 bytes per iteration
 * (two get16bits reads), then fold in the 1-3 trailing bytes and finish with
 * a fixed shift/xor/add avalanche sequence. */
702 | #define HASH_SFH(key,keylen,hashv) \
703 | do { \
704 | unsigned const char *_sfh_key=(unsigned const char*)(key); \
705 | uint32_t _sfh_tmp, _sfh_len = (uint32_t)keylen; \
706 | \
707 | unsigned _sfh_rem = _sfh_len & 3U; \
708 | _sfh_len >>= 2; \
709 | hashv = 0xcafebabeu; \
710 | \
711 | /* Main loop */ \
712 | for (;_sfh_len > 0U; _sfh_len--) { \
713 | hashv += get16bits (_sfh_key); \
714 | _sfh_tmp = ((uint32_t)(get16bits (_sfh_key+2)) << 11) ^ hashv; \
715 | hashv = (hashv << 16) ^ _sfh_tmp; \
716 | _sfh_key += 2U*sizeof (uint16_t); \
717 | hashv += hashv >> 11; \
718 | } \
719 | \
720 | /* Handle end cases */ \
721 | switch (_sfh_rem) { \
722 | case 3: hashv += get16bits (_sfh_key); \
723 | hashv ^= hashv << 16; \
724 | hashv ^= (uint32_t)(_sfh_key[sizeof (uint16_t)]) << 18; \
725 | hashv += hashv >> 11; \
726 | break; \
727 | case 2: hashv += get16bits (_sfh_key); \
728 | hashv ^= hashv << 11; \
729 | hashv += hashv >> 17; \
730 | break; \
731 | case 1: hashv += *_sfh_key; \
732 | hashv ^= hashv << 10; \
733 | hashv += hashv >> 1; \
734 | break; \
735 | default: ; \
736 | } \
737 | \
738 | /* Force "avalanching" of the final hash bits */ \
739 | hashv ^= hashv << 3; \
740 | hashv += hashv >> 5; \
741 | hashv ^= hashv << 4; \
742 | hashv += hashv >> 17; \
743 | hashv ^= hashv << 25; \
744 | hashv += hashv >> 6; \
745 | } while (0)
746 |
747 | /* iterate over items in a known bucket to find desired item */
/* Walks the bucket's hh_next chain starting at (head).hh_head. An item
 * matches when its stored hashv and keylen equal the search values and
 * HASH_KEYCMP reports equal key bytes. On exit `out` points at the matching
 * element, or is NULL when none matched. */
748 | #define HASH_FIND_IN_BKT(tbl,hh,head,keyptr,keylen_in,hashval,out) \
749 | do { \
750 | if ((head).hh_head != NULL) { \
751 | DECLTYPE_ASSIGN(out, ELMT_FROM_HH(tbl, (head).hh_head)); \
752 | } else { \
753 | (out) = NULL; \
754 | } \
755 | while ((out) != NULL) { \
756 | if ((out)->hh.hashv == (hashval) && (out)->hh.keylen == (keylen_in)) { \
757 | if (HASH_KEYCMP((out)->hh.key, keyptr, keylen_in) == 0) { \
758 | break; \
759 | } \
760 | } \
761 | if ((out)->hh.hh_next != NULL) { \
762 | DECLTYPE_ASSIGN(out, ELMT_FROM_HH(tbl, (out)->hh.hh_next)); \
763 | } else { \
764 | (out) = NULL; \
765 | } \
766 | } \
767 | } while (0)
768 |
769 | /* add an item to a bucket */
/* The handle is pushed onto the front of the bucket chain and the bucket
 * count incremented. When the count reaches
 * (expand_mult + 1) * HASH_BKT_CAPACITY_THRESH and the table has not been
 * marked noexpand, the bucket array is expanded; if that expansion records
 * an out-of-memory condition (non-fatal OOM builds only), the handle is
 * unlinked from the bucket again. */
770 | #define HASH_ADD_TO_BKT(head,hh,addhh,oomed) \
771 | do { \
772 | UT_hash_bucket *_ha_head = &(head); \
773 | _ha_head->count++; \
774 | (addhh)->hh_next = _ha_head->hh_head; \
775 | (addhh)->hh_prev = NULL; \
776 | if (_ha_head->hh_head != NULL) { \
777 | _ha_head->hh_head->hh_prev = (addhh); \
778 | } \
779 | _ha_head->hh_head = (addhh); \
780 | if ((_ha_head->count >= ((_ha_head->expand_mult + 1U) * HASH_BKT_CAPACITY_THRESH)) \
781 | && !(addhh)->tbl->noexpand) { \
782 | HASH_EXPAND_BUCKETS(addhh,(addhh)->tbl, oomed); \
783 | IF_HASH_NONFATAL_OOM( \
784 | if (oomed) { \
785 | HASH_DEL_IN_BKT(head,addhh); \
786 | } \
787 | ) \
788 | } \
789 | } while (0)
790 |
791 | /* remove an item from a given bucket */
/* Unlinks delhh from the bucket's doubly linked chain (fixing the bucket
 * head and the neighbours' links) and decrements the bucket count. The
 * hh_prev/hh_next fields of delhh itself are left unchanged. */
792 | #define HASH_DEL_IN_BKT(head,delhh) \
793 | do { \
794 | UT_hash_bucket *_hd_head = &(head); \
795 | _hd_head->count--; \
796 | if (_hd_head->hh_head == (delhh)) { \
797 | _hd_head->hh_head = (delhh)->hh_next; \
798 | } \
799 | if ((delhh)->hh_prev) { \
800 | (delhh)->hh_prev->hh_next = (delhh)->hh_next; \
801 | } \
802 | if ((delhh)->hh_next) { \
803 | (delhh)->hh_next->hh_prev = (delhh)->hh_prev; \
804 | } \
805 | } while (0)
806 |
807 | /* Bucket expansion has the effect of doubling the number of buckets
808 | * and redistributing the items into the new buckets. Ideally the
809 | * items will distribute more or less evenly into the new buckets
810 | * (the extent to which this is true is a measure of the quality of
811 | * the hash function as it applies to the key domain).
812 | *
813 | * With the items distributed into more buckets, the chain length
814 | * (item count) in each bucket is reduced. Thus by expanding buckets
815 | * the hash keeps a bound on the chain length. This bounded chain
816 | * length is the essence of how a hash provides constant time lookup.
817 | *
818 | * The calculation of tbl->ideal_chain_maxlen below deserves some
819 | * explanation. First, keep in mind that we're calculating the ideal
820 | * maximum chain length based on the *new* (doubled) bucket count.
821 | * In fractions this is just n/b (n=number of items,b=new num buckets).
822 | * Since the ideal chain length is an integer, we want to calculate
823 | * ceil(n/b). We don't depend on floating point arithmetic in this
824 | * hash, so to calculate ceil(n/b) with integers we could write
825 | *
826 | * ceil(n/b) = (n/b) + ((n%b)?1:0)
827 | *
828 | * and in fact a previous version of this hash did just that.
829 | * But now we have improved things a bit by recognizing that b is
830 | * always a power of two. We keep its base 2 log handy (call it lb),
831 | * so now we can write this with a bit shift and logical AND:
832 | *
833 | * ceil(n/b) = (n>>lb) + ( (n & (b-1)) ? 1:0)
834 | *
835 | */
/* Failure handling: if the doubled bucket array cannot be allocated, the
 * OOM is recorded via HASH_RECORD_OOM and the table is left untouched.
 * Otherwise every handle is rehashed into the zeroed new array (pushed on the
 * front of its new chain), expand_mult is recomputed per new bucket, and the
 * old array is freed. An expansion after which more than half the items sit
 * beyond the ideal chain length counts as ineffective; two in a row set
 * tbl->noexpand, disabling further expansion. */
836 | #define HASH_EXPAND_BUCKETS(hh,tbl,oomed) \
837 | do { \
838 | unsigned _he_bkt; \
839 | unsigned _he_bkt_i; \
840 | struct UT_hash_handle *_he_thh, *_he_hh_nxt; \
841 | UT_hash_bucket *_he_new_buckets, *_he_newbkt; \
842 | _he_new_buckets = (UT_hash_bucket*)uthash_malloc( \
843 | sizeof(struct UT_hash_bucket) * (tbl)->num_buckets * 2U); \
844 | if (!_he_new_buckets) { \
845 | HASH_RECORD_OOM(oomed); \
846 | } else { \
847 | uthash_bzero(_he_new_buckets, \
848 | sizeof(struct UT_hash_bucket) * (tbl)->num_buckets * 2U); \
849 | (tbl)->ideal_chain_maxlen = \
850 | ((tbl)->num_items >> ((tbl)->log2_num_buckets+1U)) + \
851 | ((((tbl)->num_items & (((tbl)->num_buckets*2U)-1U)) != 0U) ? 1U : 0U); \
852 | (tbl)->nonideal_items = 0; \
853 | for (_he_bkt_i = 0; _he_bkt_i < (tbl)->num_buckets; _he_bkt_i++) { \
854 | _he_thh = (tbl)->buckets[ _he_bkt_i ].hh_head; \
855 | while (_he_thh != NULL) { \
856 | _he_hh_nxt = _he_thh->hh_next; \
857 | HASH_TO_BKT(_he_thh->hashv, (tbl)->num_buckets * 2U, _he_bkt); \
858 | _he_newbkt = &(_he_new_buckets[_he_bkt]); \
859 | if (++(_he_newbkt->count) > (tbl)->ideal_chain_maxlen) { \
860 | (tbl)->nonideal_items++; \
861 | if (_he_newbkt->count > _he_newbkt->expand_mult * (tbl)->ideal_chain_maxlen) { \
862 | _he_newbkt->expand_mult++; \
863 | } \
864 | } \
865 | _he_thh->hh_prev = NULL; \
866 | _he_thh->hh_next = _he_newbkt->hh_head; \
867 | if (_he_newbkt->hh_head != NULL) { \
868 | _he_newbkt->hh_head->hh_prev = _he_thh; \
869 | } \
870 | _he_newbkt->hh_head = _he_thh; \
871 | _he_thh = _he_hh_nxt; \
872 | } \
873 | } \
874 | uthash_free((tbl)->buckets, (tbl)->num_buckets * sizeof(struct UT_hash_bucket)); \
875 | (tbl)->num_buckets *= 2U; \
876 | (tbl)->log2_num_buckets++; \
877 | (tbl)->buckets = _he_new_buckets; \
878 | (tbl)->ineff_expands = ((tbl)->nonideal_items > ((tbl)->num_items >> 1)) ? \
879 | ((tbl)->ineff_expands+1U) : 0U; \
880 | if ((tbl)->ineff_expands > 1U) { \
881 | (tbl)->noexpand = 1; \
882 | uthash_noexpand_fyi(tbl); \
883 | } \
884 | uthash_expand_fyi(tbl); \
885 | } \
886 | } while (0)
887 |
888 |
889 | /* This is an adaptation of Simon Tatham's O(n log(n)) mergesort */
890 | /* Note that HASH_SORT assumes the hash handle name to be hh.
891 | * HASH_SRT was added to allow the hash handle name to be passed in. */
/* Bottom-up merge sort of the application-order (prev/next) list: run size
 * starts at 1 and doubles each pass until a pass performs at most one merge.
 * When cmpfcn(p, q) <= 0 the element from the earlier run is taken first, so
 * equal elements keep their relative order. Only prev/next links,
 * tbl->tail and `head` are rewritten; bucket chains are not touched. */
892 | #define HASH_SORT(head,cmpfcn) HASH_SRT(hh,head,cmpfcn)
893 | #define HASH_SRT(hh,head,cmpfcn) \
894 | do { \
895 | unsigned _hs_i; \
896 | unsigned _hs_looping,_hs_nmerges,_hs_insize,_hs_psize,_hs_qsize; \
897 | struct UT_hash_handle *_hs_p, *_hs_q, *_hs_e, *_hs_list, *_hs_tail; \
898 | if (head != NULL) { \
899 | _hs_insize = 1; \
900 | _hs_looping = 1; \
901 | _hs_list = &((head)->hh); \
902 | while (_hs_looping != 0U) { \
903 | _hs_p = _hs_list; \
904 | _hs_list = NULL; \
905 | _hs_tail = NULL; \
906 | _hs_nmerges = 0; \
907 | while (_hs_p != NULL) { \
908 | _hs_nmerges++; \
909 | _hs_q = _hs_p; \
910 | _hs_psize = 0; \
911 | for (_hs_i = 0; _hs_i < _hs_insize; ++_hs_i) { \
912 | _hs_psize++; \
913 | _hs_q = ((_hs_q->next != NULL) ? \
914 | HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \
915 | if (_hs_q == NULL) { \
916 | break; \
917 | } \
918 | } \
919 | _hs_qsize = _hs_insize; \
920 | while ((_hs_psize != 0U) || ((_hs_qsize != 0U) && (_hs_q != NULL))) { \
921 | if (_hs_psize == 0U) { \
922 | _hs_e = _hs_q; \
923 | _hs_q = ((_hs_q->next != NULL) ? \
924 | HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \
925 | _hs_qsize--; \
926 | } else if ((_hs_qsize == 0U) || (_hs_q == NULL)) { \
927 | _hs_e = _hs_p; \
928 | if (_hs_p != NULL) { \
929 | _hs_p = ((_hs_p->next != NULL) ? \
930 | HH_FROM_ELMT((head)->hh.tbl, _hs_p->next) : NULL); \
931 | } \
932 | _hs_psize--; \
933 | } else if ((cmpfcn( \
934 | DECLTYPE(head)(ELMT_FROM_HH((head)->hh.tbl, _hs_p)), \
935 | DECLTYPE(head)(ELMT_FROM_HH((head)->hh.tbl, _hs_q)) \
936 | )) <= 0) { \
937 | _hs_e = _hs_p; \
938 | if (_hs_p != NULL) { \
939 | _hs_p = ((_hs_p->next != NULL) ? \
940 | HH_FROM_ELMT((head)->hh.tbl, _hs_p->next) : NULL); \
941 | } \
942 | _hs_psize--; \
943 | } else { \
944 | _hs_e = _hs_q; \
945 | _hs_q = ((_hs_q->next != NULL) ? \
946 | HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \
947 | _hs_qsize--; \
948 | } \
949 | if ( _hs_tail != NULL ) { \
950 | _hs_tail->next = ((_hs_e != NULL) ? \
951 | ELMT_FROM_HH((head)->hh.tbl, _hs_e) : NULL); \
952 | } else { \
953 | _hs_list = _hs_e; \
954 | } \
955 | if (_hs_e != NULL) { \
956 | _hs_e->prev = ((_hs_tail != NULL) ? \
957 | ELMT_FROM_HH((head)->hh.tbl, _hs_tail) : NULL); \
958 | } \
959 | _hs_tail = _hs_e; \
960 | } \
961 | _hs_p = _hs_q; \
962 | } \
963 | if (_hs_tail != NULL) { \
964 | _hs_tail->next = NULL; \
965 | } \
966 | if (_hs_nmerges <= 1U) { \
967 | _hs_looping = 0; \
968 | (head)->hh.tbl->tail = _hs_tail; \
969 | DECLTYPE_ASSIGN(head, ELMT_FROM_HH((head)->hh.tbl, _hs_list)); \
970 | } \
971 | _hs_insize *= 2U; \
972 | } \
973 | HASH_FSCK(hh, head, "HASH_SRT"); \
974 | } \
975 | } while (0)
976 |
977 | /* This function selects items from one hash into another hash.
978 | * The end result is that the selected items have dual presence
979 | * in both hashes. There is no copy of the items made; rather
980 | * they are added into the new hash through a secondary hash
981 | * hash handle that must be present in the structure. */
/* `cond` is called with a void* to each source element, visiting the source
 * bucket by bucket. A matching element gets its hh_dst handle filled from the
 * source handle (same key pointer, key length and hash value, so nothing is
 * rehashed) and is appended to dst's application-order list. `dst` is
 * created on the first match. In non-fatal OOM builds, an element whose
 * insertion fails is passed to uthash_nonfatal_oom and skipped. */
982 | #define HASH_SELECT(hh_dst, dst, hh_src, src, cond) \
983 | do { \
984 | unsigned _src_bkt, _dst_bkt; \
985 | void *_last_elt = NULL, *_elt; \
986 | UT_hash_handle *_src_hh, *_dst_hh, *_last_elt_hh=NULL; \
987 | ptrdiff_t _dst_hho = ((char*)(&(dst)->hh_dst) - (char*)(dst)); \
988 | if ((src) != NULL) { \
989 | for (_src_bkt=0; _src_bkt < (src)->hh_src.tbl->num_buckets; _src_bkt++) { \
990 | for (_src_hh = (src)->hh_src.tbl->buckets[_src_bkt].hh_head; \
991 | _src_hh != NULL; \
992 | _src_hh = _src_hh->hh_next) { \
993 | _elt = ELMT_FROM_HH((src)->hh_src.tbl, _src_hh); \
994 | if (cond(_elt)) { \
995 | IF_HASH_NONFATAL_OOM( int _hs_oomed = 0; ) \
996 | _dst_hh = (UT_hash_handle*)(void*)(((char*)_elt) + _dst_hho); \
997 | _dst_hh->key = _src_hh->key; \
998 | _dst_hh->keylen = _src_hh->keylen; \
999 | _dst_hh->hashv = _src_hh->hashv; \
1000 | _dst_hh->prev = _last_elt; \
1001 | _dst_hh->next = NULL; \
1002 | if (_last_elt_hh != NULL) { \
1003 | _last_elt_hh->next = _elt; \
1004 | } \
1005 | if ((dst) == NULL) { \
1006 | DECLTYPE_ASSIGN(dst, _elt); \
1007 | HASH_MAKE_TABLE(hh_dst, dst, _hs_oomed); \
1008 | IF_HASH_NONFATAL_OOM( \
1009 | if (_hs_oomed) { \
1010 | uthash_nonfatal_oom(_elt); \
1011 | (dst) = NULL; \
1012 | continue; \
1013 | } \
1014 | ) \
1015 | } else { \
1016 | _dst_hh->tbl = (dst)->hh_dst.tbl; \
1017 | } \
1018 | HASH_TO_BKT(_dst_hh->hashv, _dst_hh->tbl->num_buckets, _dst_bkt); \
1019 | HASH_ADD_TO_BKT(_dst_hh->tbl->buckets[_dst_bkt], hh_dst, _dst_hh, _hs_oomed); \
1020 | (dst)->hh_dst.tbl->num_items++; \
1021 | IF_HASH_NONFATAL_OOM( \
1022 | if (_hs_oomed) { \
1023 | HASH_ROLLBACK_BKT(hh_dst, dst, _dst_hh); \
1024 | HASH_DELETE_HH(hh_dst, dst, _dst_hh); \
1025 | _dst_hh->tbl = NULL; \
1026 | uthash_nonfatal_oom(_elt); \
1027 | continue; \
1028 | } \
1029 | ) \
1030 | HASH_BLOOM_ADD(_dst_hh->tbl, _dst_hh->hashv); \
1031 | _last_elt = _elt; \
1032 | _last_elt_hh = _dst_hh; \
1033 | } \
1034 | } \
1035 | } \
1036 | } \
1037 | HASH_FSCK(hh_dst, dst, "HASH_SELECT"); \
1038 | } while (0)
1039 |
/* HASH_CLEAR: free the hash's own bookkeeping (bloom filter, bucket array and
 * table) and set `head` to NULL. The elements themselves are not freed. */
1040 | #define HASH_CLEAR(hh,head) \
1041 | do { \
1042 | if ((head) != NULL) { \
1043 | HASH_BLOOM_FREE((head)->hh.tbl); \
1044 | uthash_free((head)->hh.tbl->buckets, \
1045 | (head)->hh.tbl->num_buckets*sizeof(struct UT_hash_bucket)); \
1046 | uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \
1047 | (head) = NULL; \
1048 | } \
1049 | } while (0)
1050 |
/* HASH_OVERHEAD: bytes used by the hash bookkeeping (one UT_hash_handle per
 * item, the bucket array, the table, plus HASH_BLOOM_BYTELEN); 0U when
 * `head` is NULL. */
1051 | #define HASH_OVERHEAD(hh,head) \
1052 | (((head) != NULL) ? ( \
1053 | (size_t)(((head)->hh.tbl->num_items * sizeof(UT_hash_handle)) + \
1054 | ((head)->hh.tbl->num_buckets * sizeof(UT_hash_bucket)) + \
1055 | sizeof(UT_hash_table) + \
1056 | (HASH_BLOOM_BYTELEN))) : 0U)
1057 |
/* HASH_ITER: visit elements in application order. `tmp` is advanced to the
 * next element before the loop body runs, so the body may delete `el`.
 * Without decltype support (NO_DECLTYPE) `tmp` is assigned through a char**
 * cast instead of a DECLTYPE cast. */
1058 | #ifdef NO_DECLTYPE
1059 | #define HASH_ITER(hh,head,el,tmp) \
1060 | for(((el)=(head)), ((*(char**)(&(tmp)))=(char*)((head!=NULL)?(head)->hh.next:NULL)); \
1061 | (el) != NULL; ((el)=(tmp)), ((*(char**)(&(tmp)))=(char*)((tmp!=NULL)?(tmp)->hh.next:NULL)))
1062 | #else
1063 | #define HASH_ITER(hh,head,el,tmp) \
1064 | for(((el)=(head)), ((tmp)=DECLTYPE(el)((head!=NULL)?(head)->hh.next:NULL)); \
1065 | (el) != NULL; ((el)=(tmp)), ((tmp)=DECLTYPE(el)((tmp!=NULL)?(tmp)->hh.next:NULL)))
1066 | #endif
1067 |
1068 | /* obtain a count of items in the hash */
/* HASH_CNT reads tbl->num_items (0U for a NULL head); HASH_COUNT is the
 * shorthand for the default handle name `hh`. */
1069 | #define HASH_COUNT(head) HASH_CNT(hh,head)
1070 | #define HASH_CNT(hh,head) ((head != NULL)?((head)->hh.tbl->num_items):0U)
1071 |
/* One bucket: head of a doubly linked chain of handles (linked through
 * hh_prev/hh_next), the chain length, and the expansion multiplier. */
1072 | typedef struct UT_hash_bucket {
1073 | struct UT_hash_handle *hh_head;
1074 | unsigned count;
1075 |
1076 | /* expand_mult is normally set to 0. In this situation, the max chain length
1077 | * threshold is enforced at its default value, HASH_BKT_CAPACITY_THRESH. (If
1078 | * the bucket's chain exceeds this length, bucket expansion is triggered).
1079 | * However, setting expand_mult to a non-zero value delays bucket expansion
1080 | * (that would be triggered by additions to this particular bucket)
1081 | * until its chain length reaches a *multiple* of HASH_BKT_CAPACITY_THRESH.
1082 | * (The multiplier is simply expand_mult+1). The whole idea of this
1083 | * multiplier is to reduce bucket expansions, since they are expensive, in
1084 | * situations where we know that a particular bucket tends to be overused.
1085 | * It is better to let its chain length grow to a longer yet-still-bounded
1086 | * value, than to do an O(n) bucket expansion too often.
1087 | */
1088 | unsigned expand_mult;
1089 |
1090 | } UT_hash_bucket;
1091 |
1092 | /* random signature used only to find hash tables in external analysis */
1093 | #define HASH_SIGNATURE 0xa0111fe1u
1094 | #define HASH_BLOOM_SIGNATURE 0xb12220f2u
1095 |
/* Per-hash table state; every element's handle points at the same table
 * through UT_hash_handle.tbl. */
1096 | typedef struct UT_hash_table {
1097 | UT_hash_bucket *buckets;
1098 | unsigned num_buckets, log2_num_buckets;
1099 | unsigned num_items;
1100 | struct UT_hash_handle *tail; /* tail hh in app order, for fast append */
1101 | ptrdiff_t hho; /* hash handle offset (byte pos of hash handle in element) */
1102 |
1103 | /* in an ideal situation (all buckets used equally), no bucket would have
1104 | * more than ceil(#items/#buckets) items. that's the ideal chain length. */
1105 | unsigned ideal_chain_maxlen;
1106 |
1107 | /* nonideal_items is the number of items in the hash whose chain position
1108 | * exceeds the ideal chain maxlen. these items pay the penalty for an uneven
1109 | * hash distribution; reaching them in a chain traversal takes >ideal steps */
1110 | unsigned nonideal_items;
1111 |
1112 | /* ineffective expands occur when a bucket doubling was performed, but
1113 | * afterward, more than half the items in the hash had nonideal chain
1114 | * positions. If this happens on two consecutive expansions we inhibit any
1115 | * further expansion, as it's not helping; this happens when the hash
1116 | * function isn't a good fit for the key domain. When expansion is inhibited
1117 | * the hash will still work, albeit no longer in constant time. */
1118 | unsigned ineff_expands, noexpand;
1119 |
1120 | uint32_t signature; /* used only to find hash tables in external analysis */
1121 | #ifdef HASH_BLOOM
1122 | uint32_t bloom_sig; /* used only to test bloom exists in external analysis */
1123 | uint8_t *bloom_bv;
1124 | uint8_t bloom_nbits;
1125 | #endif
1126 |
1127 | } UT_hash_table;
1128 |
/* Embedded in every element. It links the element into two lists at once:
 * the application-order list (prev/next hold element pointers) and its
 * bucket chain (hh_prev/hh_next hold handle pointers). */
1129 | typedef struct UT_hash_handle {
1130 | struct UT_hash_table *tbl;
1131 | void *prev; /* prev element in app order */
1132 | void *next; /* next element in app order */
1133 | struct UT_hash_handle *hh_prev; /* previous hh in bucket order */
1134 | struct UT_hash_handle *hh_next; /* next hh in bucket order */
1135 | const void *key; /* ptr to enclosing struct's key */
1136 | unsigned keylen; /* enclosing struct's key len */
1137 | unsigned hashv; /* result of hash-fcn(key) */
1138 | } UT_hash_handle;
1139 |
1140 | #endif /* UTHASH_H */
1141 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from operator import itemgetter
5 | import torch.nn.utils
6 | from surel_gacc import sjoin
7 | from sklearn.metrics import roc_auc_score
8 | from torch.nn import BCEWithLogitsLoss
9 |
10 | from utils import *
11 |
12 |
def train(model, opti, data, dT):
    """Train ``model`` for one pass over ``data``.

    Args:
        model: module called as ``model(Tf, [wl, wr])``; it returns logits that
            must have the same shape as ``label`` (required by BCE-with-logits).
        opti: optimizer; gradients are reset and one step is taken per batch.
        data: iterable of ``(wl, wr, label, x)`` batches (``x`` is not used here).
        dT: indexable container; ``dT[wl]`` and ``dT[wr]`` are stacked along a
            new leading dimension to form the model input.

    Returns:
        Tuple ``(mean_loss, auc)``: the per-sample average loss over the pass
        and the ROC-AUC of the sigmoid predictions against all labels.
    """
    model.train()
    criterion = BCEWithLogitsLoss()  # stateless; build once instead of per batch
    total_loss = 0
    labels, preds = [], []
    for wl, wr, label, _ in data:
        labels.append(label)
        Tf = torch.stack([dT[wl], dT[wr]])
        opti.zero_grad()
        pred = model(Tf, [wl, wr])
        preds.append(pred.detach().sigmoid())
        loss = criterion(pred, label.to(pred.device))
        loss.backward()
        # Clip only after backward(): before it the gradients had just been
        # reset by opti.zero_grad(), so clipping there never bounded the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opti.step()
        total_loss += loss.item() * len(label)
    predictions = torch.cat(preds).cpu()
    labels = torch.cat(labels)
    return total_loss / len(labels), roc_auc_score(labels, predictions)
32 |
33 |
def eval_model(model, x_dict, x_set, args, evaluator, device, mode='test', return_predictions=False):
    """Score the query edges of split ``mode`` and optionally compute metrics.

    Args:
        model: module called as ``model(gT, x)``; its outputs are logits.
        x_dict: maps node id -> ``(S, K, F)`` triple consumed by ``sjoin`` and
            the feature concatenation below.
        x_set: split dictionary; ``x_set[mode]['E']`` holds the query edges and
            ``x_set[mode]['num_pos']`` / ``['num_neg']`` the label counts.
        args: needs ``batch_num`` and ``metric`` (plus whatever
            ``normalization`` reads).
        evaluator: object with an ``eval`` method (used for ``mrr``); it is also
            passed through to ``evaluate_hits``.
        device: device the gathered node features are moved to.
        mode: split name to evaluate.
        return_predictions: if True, return the raw sigmoid scores instead.

    Returns:
        The concatenated sigmoid predictions, or a dict with ``metric``,
        ``mode``, ``auc`` and either ``mrr_list`` or ``hits``/``num_pos``.

    Raises:
        NotImplementedError: if ``args.metric`` is neither ``'mrr'`` nor
            contains ``'Hits'``.
    """
    model.eval()
    preds = []
    with torch.no_grad():
        target = x_set[mode]['E']
        with tqdm(total=len(target)) as pbar:
            for batch in gen_batch(target, args.batch_num, keep=True):
                Bs = torch.unique(batch).numpy()
                # Look nodes up one by one: itemgetter(*Bs) returns the bare
                # value (not a 1-tuple) when the batch touches a single node,
                # which would make zip(*...) unpack the wrong structure.
                S, K, F = zip(*(x_dict[b] for b in Bs))
                S = torch.from_numpy(np.asarray(S)).long()
                F = np.concatenate(F)
                # Prepend an all-zero feature row at index 0 (presumably the
                # "no entry" index produced by sjoin -- confirm).
                F = np.concatenate([[[0] * F.shape[-1]], F])
                mF = torch.from_numpy(F).to(device)
                uvw, _ = sjoin(S, K, batch, return_idx=True)
                uvw = uvw.reshape(2, -1, 2)
                x = torch.from_numpy(uvw)
                gT = normalization(mF, args)
                gT = torch.stack([gT[uvw[0]], gT[uvw[1]]])
                pred = model(gT, x)
                preds.append(pred.sigmoid())
                pbar.update(len(pred))
        predictions = torch.cat(preds, dim=0)

    if return_predictions:
        return predictions

    labels = torch.zeros(len(predictions))
    result_dict = {'metric': args.metric, 'mode': mode}
    if args.metric == 'mrr':
        # Layout: the first num_pos scores are positives; the rest are their
        # negatives, reshaped to one row per positive.
        num_pos = x_set[mode]['num_pos']
        labels[:num_pos] = 1
        pred_pos, pred_neg = predictions[:num_pos], predictions[num_pos:]
        result_dict['mrr_list'] = \
            evaluator.eval({"y_pred_pos": pred_pos.view(-1), "y_pred_neg": pred_neg.view(num_pos, -1)})['mrr_list']
    elif 'Hits' in args.metric:
        # Layout: the first num_neg scores are negatives; the rest are positives.
        num_neg = x_set[mode]['num_neg']
        labels[num_neg:] = 1
        pred_neg, pred_pos = predictions[:num_neg], predictions[num_neg:]
        result_dict['hits'] = evaluate_hits(pred_pos.view(-1), pred_neg.view(-1), evaluator)
        result_dict['num_pos'] = len(pred_pos)
    else:
        raise NotImplementedError

    result_dict['auc'] = roc_auc_score(labels, predictions.cpu())
    return result_dict
80 |
81 |
def eval_model_horder(model, x_dict, x_set, args, evaluator, device, mode='test', return_predictions=False):
    """Score the query tuples ``(u, v, w)`` of split ``mode`` and optionally compute metrics.

    Same contract as ``eval_model``, except that each query row has three node
    columns. ``sjoin`` is run on the ``(u, w)`` and ``(v, w)`` column pairs and
    the two results are concatenated (axis 1) before being reshaped to
    ``(2, -1, 2)``.

    Args:
        model: module called as ``model(gT, x)``; its outputs are logits.
        x_dict: maps node id -> ``(S, K, F)`` triple.
        x_set: split dictionary; ``x_set[mode]['E']`` holds the query tuples and
            ``x_set[mode]['num_pos']`` / ``['num_neg']`` the label counts.
        args: needs ``batch_num`` and ``metric`` (plus whatever
            ``normalization`` reads).
        evaluator: object with an ``eval`` method (used for ``mrr``); it is also
            passed through to ``evaluate_hits``.
        device: device the gathered node features are moved to.
        mode: split name to evaluate.
        return_predictions: if True, return the raw sigmoid scores instead.

    Returns:
        The concatenated sigmoid predictions, or a dict with ``metric``,
        ``mode``, ``auc`` and either ``mrr_list`` or ``hits``/``num_pos``.

    Raises:
        NotImplementedError: if ``args.metric`` is neither ``'mrr'`` nor
            contains ``'Hits'``.
    """
    model.eval()
    preds = []
    with torch.no_grad():
        target = x_set[mode]['E']
        with tqdm(total=len(target)) as pbar:
            for batch in gen_batch(target, args.batch_num, keep=True):
                Bs = torch.unique(batch).numpy()
                # Look nodes up one by one: itemgetter(*Bs) returns the bare
                # value (not a 1-tuple) when the batch touches a single node,
                # which would make zip(*...) unpack the wrong structure.
                S, K, F = zip(*(x_dict[b] for b in Bs))
                S = torch.from_numpy(np.asarray(S)).long()
                F = np.concatenate(F)
                # Prepend an all-zero feature row at index 0 (presumably the
                # "no entry" index produced by sjoin -- confirm).
                F = np.concatenate([[[0] * F.shape[-1]], F])
                mF = torch.from_numpy(F).to(device)
                uw = sjoin(S, K, batch[:, [0, 2]], return_idx=False)
                vw = sjoin(S, K, batch[:, [1, 2]], return_idx=False)
                uvw = np.concatenate([uw, vw], axis=1).reshape(2, -1, 2)
                x = torch.from_numpy(uvw)
                gT = normalization(mF, args)
                gT = torch.stack([gT[uvw[0]], gT[uvw[1]]])
                pred = model(gT, x)
                preds.append(pred.sigmoid())
                pbar.update(len(pred))
        predictions = torch.cat(preds, dim=0)

    if return_predictions:
        return predictions

    labels = torch.zeros(len(predictions))
    result_dict = {'metric': args.metric, 'mode': mode}
    if args.metric == 'mrr':
        # Layout: the first num_pos scores are positives; the rest are their
        # negatives, reshaped to one row per positive.
        num_pos = x_set[mode]['num_pos']
        labels[:num_pos] = 1
        pred_pos, pred_neg = predictions[:num_pos], predictions[num_pos:]
        result_dict['mrr_list'] = \
            evaluator.eval({"y_pred_pos": pred_pos.view(-1), "y_pred_neg": pred_neg.view(num_pos, -1)})['mrr_list']
    elif 'Hits' in args.metric:
        # Layout: the first num_neg scores are negatives; the rest are positives.
        num_neg = x_set[mode]['num_neg']
        labels[num_neg:] = 1
        pred_neg, pred_pos = predictions[:num_neg], predictions[num_neg:]
        result_dict['hits'] = evaluate_hits(pred_pos.view(-1), pred_neg.view(-1), evaluator)
        result_dict['num_pos'] = len(pred_pos)
    else:
        raise NotImplementedError

    result_dict['auc'] = roc_auc_score(labels, predictions.cpu())
    return result_dict
129 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import random
6 | import numpy as np
7 | import torch
8 | from surel_gacc import run_walk
9 | from tqdm import tqdm
10 |
def set_random_seed(args):
    """Make a run reproducible by seeding every RNG from ``args.seed``.

    Seeds Python's ``random``, NumPy's global generator and torch (CPU and all
    CUDA devices), forces deterministic cuDNN kernels with autotuning off, and
    exports ``PYTHONHASHSEED``. The environment variable only affects Python
    processes started afterwards; the running interpreter's hash seed is fixed
    at start-up.
    """
    s = args.seed
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    os.environ['PYTHONHASHSEED'] = str(s)
20 |
def gen_batch(iterable, n=1, keep=False):
    """Yield consecutive slices of ``iterable`` of length ``n``.

    Args:
        iterable: any sliceable sequence with ``len`` (list, ndarray, tensor).
        n: batch size.
        keep: if True, a final shorter slice is yielded as well; if False only
            full-size batches are produced.

    Yields:
        ``iterable[i:i + n]`` for ``i = 0, n, 2n, ...``.
    """
    length = len(iterable)
    # keep=False stops at the last start index that still leaves n items. The
    # old bound (length - n) also dropped a final batch ending exactly at
    # `length`, i.e. a *full* batch, and yielded nothing when length == n.
    stop = length if keep else length - n + 1
    for ndx in range(0, stop, n):
        yield iterable[ndx:min(ndx + n, length)]
29 |
30 |
def np_sampling(rw_dict, ptr, neighs, bsize, target, num_walks=100, num_steps=3):
    """Sample random walks for every node in ``target`` and record them in ``rw_dict``.

    Nodes are processed in chunks of ``bsize`` (the last chunk may be shorter),
    with a progress bar over ``target``. For the i-th node of a chunk the stored
    value is ``(walks[i], freqs[i, 0], freqs[i, 1])``, where ``walks, freqs``
    is what ``run_walk`` returns for that chunk. Presumably the two ``freqs``
    columns are visited node ids and their visit counts -- confirm against
    ``run_walk``.

    Returns:
        ``rw_dict`` itself, updated in place.
    """
    with tqdm(total=len(target)) as progress:
        for chunk in gen_batch(target, bsize, True):
            walks, freqs = run_walk(ptr, neighs, chunk, num_walks=num_walks,
                                    num_steps=num_steps, replacement=True)
            ids, counts = freqs[:, 0], freqs[:, 1]
            per_node = zip(walks, ids, counts)
            rw_dict.update(dict(zip(chunk, per_node)))
            progress.update(len(chunk))
    return rw_dict
39 |
40 |
def sample(high: int, size: int, device=None):
    """Draw ``min(high, size)`` distinct integers from ``range(high)``.

    Sampling uses Python's ``random`` module (without replacement), so it is
    reproducible under ``random.seed``. The result is always an int64 tensor,
    including the empty case, so it can be used directly as an index.

    Args:
        high: exclusive upper bound of the sampled values.
        size: requested number of samples; capped at ``high``.
        device: optional torch device for the returned tensor.
    """
    size = min(high, size)
    # Without an explicit dtype, torch.tensor([]) is float32, which cannot be
    # used to index a tensor (e.g. when no candidate pairs exist).
    return torch.tensor(random.sample(range(high), size), dtype=torch.long, device=device)
44 |
45 |
def coarse(Tx, K):
    """Build compact index maps for root nodes and the nodes their walks visit.

    Args:
        Tx: root node ids (sequence of non-negative ints).
        K: one integer array per root; ``K[i]`` lists the walk nodes of ``Tx[i]``.

    Returns:
        Tuple ``(xm, ym, base, mB)``:
        xm: long tensor of length ``max(Tx) + 1`` mapping a root id to its
            position in ``Tx`` (-1 for ids that are not roots).
        ym: long tensor of length ``max(walk node id) + 1`` mapping a walk node
            id to its rank among the distinct walk nodes (-1 if never visited).
        base: number of distinct walk nodes.
        mB: flat long tensor of length ``len(Tx) * base``. Entry
            ``i * base + ym[v]`` holds the 1-based position of ``v`` in the
            concatenation of ``K`` for root ``i``, and 0 if ``v`` is not in
            ``K[i]``.
    """
    counts = [len(walk_nodes) for walk_nodes in K]
    visited = np.concatenate(K)
    distinct = np.unique(visited)  # sorted distinct walk node ids
    n_roots, base = len(Tx), len(distinct)

    # owner[j] = index of the root whose list contributed visited[j]
    owner = torch.from_numpy(np.repeat(np.arange(n_roots), counts))

    xm = torch.full((int(max(Tx)) + 1,), -1, dtype=torch.long)
    xm[Tx] = torch.arange(n_roots)
    ym = torch.full((int(max(visited)) + 1,), -1, dtype=torch.long)
    ym[distinct] = torch.arange(base)

    mB = torch.zeros(n_roots * base, dtype=torch.long)
    mB[owner * base + ym[visited]] = torch.arange(len(visited)) + 1
    return xm, ym, base, mB
64 |
65 |
def gen_sample(S, Tx, K, pos_edges, full_edges, x_embed, args, gtype='Homogeneous'):
    """Yield shuffled training batches of positive and sampled negative node pairs.

    Args:
        S: NumPy walk array indexed by remapped root position (``Ws[uidx]``);
            each row presumably holds ``num_step * num_walk`` walk node ids.
        Tx: root node ids covered by this sample.
        K: per-root arrays of walk nodes, as expected by ``coarse``.
        pos_edges: ``(2, P)`` integer tensor of positive edges in original ids.
        full_edges: ``(2, E)`` integer tensor of known edges in original ids.
            Pairs joined by them (both directions, and for homogeneous graphs
            also by 2-hop paths) are never drawn as negatives.
        x_embed: optional node-feature tensor.
        args: needs ``num_step``, ``num_walk``, ``k`` (negatives per root),
            ``batch_size``, ``use_degree`` and ``use_htype``.
        gtype: ``'Homogeneous'``, or anything else for heterogeneous graphs.

    Yields:
        ``(wu, wv, labels, extra)`` per batch of ``batch_size * (k + 1)`` pairs.
        ``wu`` / ``wv`` have two columns with the ``mB`` entries of every walk
        node of u / v looked up against root u (column 0) and root v
        (column 1); 0 means "not visited by that root". ``labels`` is 1.0 for
        positive pairs and 0.0 for negatives. ``extra`` is None without
        ``x_embed``; with ``use_degree``/``use_htype`` it is
        ``(x_embed of walk nodes, x_embed of the pair's root ids)``; otherwise
        it is ``x_embed`` itself.
    """
    unit_size = args.num_step * args.num_walk
    num_nodes = len(Tx)
    if gtype != 'Homogeneous':
        # NOTE(review): only Tx is sorted here, not K or the rows of S, so the
        # root <-> walk pairing in coarse() and Ws[...] stays aligned only if
        # Tx is already sorted for heterogeneous inputs -- confirm with callers.
        Tx = sorted(Tx)
    Ws = torch.from_numpy(S).long()
    xm, ym, base, mB = coarse(Tx, K)
    xr = torch.tensor(Tx, dtype=torch.long)

    # Symmetric adjacency over remapped root positions; for homogeneous graphs
    # the 2-hop reachability is added so that those pairs are not negatives.
    mA = torch.zeros([num_nodes, num_nodes], dtype=torch.long)
    row, col = xm[torch.cat([full_edges, full_edges[[1, 0]]], dim=-1)]
    mA[row, col] = 1
    if gtype == 'Homogeneous':
        mA = mA @ mA + mA

    # Candidate negatives: every (row, col) with a zero entry in mA, including
    # (i, i) pairs whose diagonal entry is zero. Up to k * num_nodes of them
    # are drawn without replacement.
    perm = torch.arange(num_nodes * num_nodes)[~(mA.view(-1) > 0)]
    neg_pair = torch.vstack([torch.div(perm, num_nodes, rounding_mode='floor'), perm % num_nodes]).t()
    perms = sample(len(neg_pair), args.k * num_nodes)
    neg_edges = neg_pair[perms].t()

    # Positives first (label 1), then negatives (label 0); order shuffled below.
    edge_pairs = torch.cat([xm[pos_edges], neg_edges], dim=1)
    labels = torch.zeros(edge_pairs.shape[1])
    labels[:pos_edges.shape[1]] = 1
    idx = np.random.permutation(len(labels))

    batch_size = args.batch_size * (args.k + 1)

    for bidx in gen_batch(idx, batch_size, keep=True):
        batch = edge_pairs[:, bidx]
        uidx, vidx = batch
        # Row offsets into the flat mB table, repeated once per walk entry.
        ubase = (uidx * base).repeat_interleave(unit_size)
        vbase = (vidx * base).repeat_interleave(unit_size)
        Wu, Wv = Ws[uidx].view(-1), Ws[vidx].view(-1)
        u_offset, v_offset = ym[Wu], ym[Wv]
        wu = mB[torch.stack([ubase + u_offset, vbase + u_offset], dim=-1)]
        wv = mB[torch.stack([ubase + v_offset, vbase + v_offset], dim=-1)]

        if x_embed is None:
            extra = None
        elif args.use_degree or args.use_htype:
            extra = (x_embed[torch.stack([Wu, Wv])], x_embed[xr[batch]])
        else:
            extra = x_embed
        yield wu, wv, labels[bidx], extra
117 |
118 |
def gen_tuple(W, Tx, S, pos_tuple, args):
    """Yield shuffled mini-batches of positive and corrupted node triples.

    Each positive triple ``(u, v, w)`` yields ``args.k`` negatives that keep
    ``(u, v)`` and replace ``w`` with distinct nodes drawn uniformly from ``Tx``.
    A corrupted triple that happens to equal some positive triple is labelled 1.

    Args:
        W: numpy walk array; its rows are looked up with remapped node indices.
        Tx: node ids used as roots for ``coarse`` and as the pool for negatives.
        S: per-root collections of walk nodes, forwarded to ``coarse``.
        pos_tuple: (P, 3) tensor of positive triples in original node ids.
        args: reads ``num_step``, ``num_walk``, ``k`` and ``batch_size``.

    Yields:
        ``(torch.cat([wu, wv]), torch.cat([uw, vw]), labels, None)`` per batch.
        ``wu`` / ``wv`` hold the ``mB`` codes of u's / v's walk nodes in the walk
        sets rooted at (u, w) / (v, w). ``uw`` / ``vw`` hold the same for w's
        walk nodes.
    """
    unit_size = args.num_step * args.num_walk
    num_pos = len(pos_tuple)
    Ws = torch.from_numpy(W).long()
    xm, ym, base, mB = coarse(Tx, S)

    # Trivial random corruption: k distinct replacement third nodes per positive.
    # Go through np.array first, because torch.tensor on a list of ndarrays is
    # extremely slow. The RNG call sequence is unchanged.
    dst_neg = torch.tensor(np.array([np.random.choice(Tx, args.k, replace=False) for _ in range(num_pos)]))
    src_neg = pos_tuple[:, :2].repeat(1, args.k).view(-1, 2)
    neg_tuple = torch.cat([src_neg, dst_neg.view(-1, 1)], dim=1)
    # A corrupted triple may coincide with a real one. Use a hash-set lookup,
    # O(P + N), instead of comparing every negative against every positive, O(P * N).
    pos_set = set(map(tuple, pos_tuple.tolist()))
    neg_label = [i + num_pos for i, t in enumerate(neg_tuple.tolist()) if tuple(t) in pos_set]
    tuples = xm[torch.cat([pos_tuple, neg_tuple]).t()]
    labels = torch.zeros(tuples.shape[1])
    labels[:num_pos] = 1
    labels[neg_label] = 1
    idx = np.random.permutation(len(labels))

    batch_size = args.batch_size * (args.k + 1)

    for bidx in gen_batch(idx, batch_size, keep=True):
        batch = tuples[:, bidx]
        uidx, vidx, widx = batch
        ubase = (uidx * base).repeat_interleave(unit_size)
        vbase = (vidx * base).repeat_interleave(unit_size)
        wbase = (widx * base).repeat_interleave(unit_size)
        Xu, Xv, Xw = Ws[uidx].view(-1), Ws[vidx].view(-1), Ws[widx].view(-1)
        u_offset, v_offset, w_offset = ym[Xu], ym[Xv], ym[Xw]
        # mB[root * base + col] is 1 + position of walk node `col` in root's set, 0 if absent
        wu = mB[torch.stack([ubase + u_offset, wbase + u_offset], dim=-1)]
        uw = mB[torch.stack([ubase + w_offset, wbase + w_offset], dim=-1)]
        wv = mB[torch.stack([vbase + v_offset, wbase + v_offset], dim=-1)]
        vw = mB[torch.stack([vbase + w_offset, wbase + w_offset], dim=-1)]

        yield torch.cat([wu, wv]), torch.cat([uw, vw]), labels[bidx], None
152 |
153 |
def normalization(T, args):
    """Scale the walk-count tensor ``T`` by configured normalizers.

    - ``args.use_weight``: divide column-wise by ``num_walk`` (``num_step``
      times) followed by ``w_max``.
    - ``args.norm == 'all'``: divide everything by ``num_walk``.
    - ``args.norm == 'root'``: divide the first column by ``num_walk`` and leave
      the remaining ``num_step`` columns as they are.

    For the column-wise variants, T's last dimension must broadcast against
    ``num_step + 1`` entries.

    Raises:
        NotImplementedError: for any other ``args.norm`` when ``use_weight`` is off.
    """
    if args.use_weight:
        per_column = [args.num_walk] * args.num_step + [args.w_max]
        divisor = torch.tensor(per_column, device=T.device)
    elif args.norm == 'all':
        divisor = args.num_walk
    elif args.norm == 'root':
        per_column = [args.num_walk] + [1] * args.num_step
        divisor = torch.tensor(per_column, device=T.device)
    else:
        raise NotImplementedError
    return T / divisor
165 |
166 |
# from https://github.com/facebookresearch/SEAL_OGB
def get_pos_neg_edges(split, split_edge, ratio=1.0, keep_neg=False):
    """Return ``(pos_edge, neg_edge)`` for ``split``, optionally subsampled.

    The format is detected from the keys of ``split_edge['train']``:

    * ``source_node`` / ``target_node`` / ``target_node_neg``: a ``ratio``
      fraction of sources is kept (they are always reshuffled, even for
      ``ratio == 1``). Each kept source is paired with every one of its
      negative targets. Both outputs are 2 x E.
    * ``edge`` / ``edge_neg``: row-wise edges transposed to 2 x E. Positives
      are subsampled only when ``ratio < 1``. Negatives are subsampled too
      unless ``keep_neg``. When there are exactly 1000 negatives per positive,
      the negatives belonging to the kept positives are retained.
    * ``hedge`` / ``hedge_neg``: triples. Positives are returned row-wise, and
      negatives come back as stored (reshaped to (-1, 3) when subsampled).

    Every subsampling step reseeds NumPy's *global* RNG with 123. Results are
    reproducible, but the caller's global RNG state is overwritten.

    Raises:
        NotImplementedError: if none of the recognised keys is present.
    """
    train_keys = split_edge['train']
    if 'source_node' in train_keys:
        return _source_target_split(split_edge[split], ratio)
    if 'edge' in train_keys:
        return _edge_split(split_edge[split], ratio, keep_neg)
    if 'hedge' in train_keys:
        return _hedge_split(split_edge[split], ratio)
    raise NotImplementedError


def _seeded_fraction(size, ratio):
    """Reseed the global NumPy RNG with 123; return the first int(ratio * size) of a permutation of range(size)."""
    np.random.seed(123)
    return np.random.permutation(size)[:int(ratio * size)]


def _source_target_split(part, ratio):
    """Handle the source/target(/negative targets) format; returns 2 x E tensors."""
    src = part['source_node']
    dst = part['target_node']
    dst_neg = part['target_node_neg']
    kept = _seeded_fraction(src.size(0), ratio)
    src, dst, dst_neg = src[kept], dst[kept], dst_neg[kept, :]
    pos = torch.stack([src, dst])
    neg = torch.stack([src.repeat_interleave(dst_neg.size(1)), dst_neg.view(-1)])
    return pos, neg


def _edge_split(part, ratio, keep_neg):
    """Handle the row-wise ``edge``/``edge_neg`` format; returns 2 x E tensors."""
    pos = part['edge'].t()
    neg = part['edge_neg'].t()
    if not ratio < 1:
        return pos, neg
    n_pos = pos.size(1)
    kept = _seeded_fraction(n_pos, ratio)
    pos = pos[:, kept]
    if keep_neg:
        return pos, neg
    n_neg = neg.size(1)
    if n_neg // n_pos == 1000:
        # Reseed even though nothing is drawn, so the global RNG ends up freshly seeded on this path too.
        np.random.seed(123)
        # Negatives come in groups of 1000 per positive: keep the groups of the kept positives.
        neg = neg.t().view(n_pos, -1, 2)[kept].reshape(-1, 2).t()
    else:
        neg = neg[:, _seeded_fraction(n_neg, ratio)]
    return pos, neg


def _hedge_split(part, ratio):
    """Handle the ``hedge``/``hedge_neg`` triple format; positives come back row-wise."""
    pos = part['hedge'].t()
    neg = part['hedge_neg']
    if ratio < 1:
        n_pos = pos.size(1)
        kept = _seeded_fraction(n_pos, ratio)
        pos = pos[:, kept]
        neg = neg.view(n_pos, -1, 3)[kept].reshape(-1, 3)
    return pos.t(), neg
216 |
217 |
def evaluate_hits(pos_pred, neg_pred, evaluator):
    """Compute Hits@K for K in 10, 20, 50 and 100.

    ``evaluator`` must expose a writable ``K`` attribute and an ``eval(dict)``
    method returning a mapping with a ``'hits@{K}'`` entry (presumably an OGB
    link-prediction Evaluator; verify against callers). ``evaluator.K`` is
    overwritten in place and is left at 100.

    Returns:
        dict mapping ``'Hits@{K}'`` to the evaluator's score, in increasing K order.
    """
    hits = {}
    for cutoff in (10, 20, 50, 100):
        evaluator.K = cutoff
        scores = evaluator.eval({
            'y_pred_pos': pos_pred,
            'y_pred_neg': neg_pred,
        })
        hits[f'Hits@{cutoff}'] = scores[f'hits@{cutoff}']
    return hits
229 |
230 |
def save_checkpoint(state, filename='checkpoint'):
    """Persist ``state`` with ``torch.save`` to ``<filename>.pth.tar``.

    Args:
        state: object to serialize. ``load_checkpoint`` reads its ``'epoch'``,
            ``'state_dict'`` and ``'optimizer'`` keys.
        filename: destination path without the ``.pth.tar`` suffix.
    """
    destination = f'{filename}.pth.tar'
    print("=> Saving checkpoint")
    torch.save(state, destination)
234 |
235 |
def load_checkpoint(model, optimizer, filename, map_location=None):
    """Restore model and optimizer state from ``<filename>.pth.tar``.

    Args:
        model: receives ``checkpoint['state_dict']`` via ``load_state_dict``.
        optimizer: receives ``checkpoint['optimizer']`` via ``load_state_dict``.
        filename: checkpoint path without the ``.pth.tar`` suffix.
        map_location: forwarded to ``torch.load``. Use e.g. ``'cpu'`` to resume
            a GPU-saved checkpoint on a host without CUDA. The default ``None``
            keeps the previous behaviour of restoring tensors to the device
            they were saved from.

    Raises:
        KeyError: for a dict checkpoint lacking 'epoch', 'state_dict' or 'optimizer'.
    """
    checkpoint = torch.load(f'{filename}.pth.tar', map_location=map_location)
    print(f"<= Loading checkpoint from epoch {checkpoint['epoch']}")
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
241 |
--------------------------------------------------------------------------------