├── .gitignore ├── LICENSE ├── README.md ├── dataloaders.py ├── environment.yml ├── log.py ├── main.py ├── main_hetro.py ├── main_horder.py ├── models ├── layer.py ├── model.py └── model_horder.py ├── subg_acc ├── .gitignore ├── LICENSE ├── README.md ├── graph_acc.c ├── setup.py └── uthash.h ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | .DS_Store 3 | .idea/* 4 | __pycache__/* 5 | notebooks/.ipynb_checkpoints/* 6 | log/* 7 | models/__pycache__/* 8 | *pyc 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022-Present, Haoteng Yin and GCoM@Purdue 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

SUREL: Subgraph-based Graph Representation Learning Framework

2 |

3 | 4 | Github 5 | 6 | OGBL 7 | Version 8 |

9 | 10 | SUREL is a novel walk-based computation framework for efficient large-scale subgraph-based graph representation learning (SGRL). Details on how SUREL works can be found in our VLDB'22 paper [Algorithm and System Co-design for Efficient Subgraph-based Graph Representation Learning](https://arxiv.org/pdf/2202.13538.pdf). 11 | 12 | Currently, we support: 13 | - Large-scale graph ML tasks: link prediction / relation-type prediction / higher-order pattern prediction 14 | - Preprocessing and training of datasets in OGB format 15 | - Python API ([SubG Library](https://github.com/VeritasYin/subg_acc)) for subgraph sampling and joining procedures 16 | - Single GPU training and evaluation 17 | - Structural (Relative Position) Encoding + Node Features 18 | - [VesselGraph](https://paperswithcode.com/dataset/vesselgraph) Dataset 19 | 20 | We are working on expanding the functionality of SUREL to include: 21 | - Multi-GPU training 22 | 23 | ## Requirements ## 24 | (Other versions may work, but are untested) 25 | * Ubuntu 20.04 26 | * CUDA >= 10.2 27 | * python >= 3.8 28 | * 1.8 <= pytorch <= 1.12 29 | 30 | ## Datasets 31 | 32 | SGRL datasets (`mag-write (P-A)`, `mag-cite (P-P)`, `tags-math`, `DBLP-coauthor`) for relation and higher-order prediction can be accessed via [Zenodo](https://zenodo.org/records/15186012). 33 | 34 | ## SGRL Environment Setup ## 35 | 36 | Requirements: Python >= 3.8, [Anaconda3](https://www.anaconda.com/) 37 | 38 | - Update conda: 39 | ```bash 40 | conda update -n base -c defaults conda 41 | ``` 42 | 43 | - Install basic dependencies into a virtual environment and activate it: 44 | ```bash 45 | conda env create -f environment.yml 46 | conda activate sgrl_env 47 | ``` 48 | - **SUREL** now supports PyTorch 1.12.1 and PyG 2.2.0.
To install them, simply run 49 | ```bash 50 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 51 | conda install pyg -c pyg 52 | ``` 53 | For more details, please refer to the [PyTorch](https://pytorch.org/) and [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) installation guides. 54 | The code in this repository was most recently tested with Python 3.10.9 + PyTorch 1.12.1 (CUDA 11.3) + torch-geometric 2.2.0. 55 | 56 | - Example installation commands for PyTorch 1.8.0 (CUDA 10.2) and torch-geometric 1.6.3: 57 | ```bash 58 | conda install pytorch==1.8.0 torchvision torchaudio cudatoolkit=10.2 59 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html 60 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html 61 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html 62 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html 63 | pip install torch-geometric==1.6.3 64 | ``` 65 | 66 | ## Quick Start 67 | 68 | 1. Install a version of PyTorch that is compatible with your CUDA driver 69 | 70 | 2. Clone the repository `git clone https://github.com/Graph-COM/SUREL` 71 | 72 | 3.
Build and install the [SubG](https://github.com/Graph-COM/SUREL/tree/main/subg_acc) library (v1.1) `cd subg_acc;python3 setup.py install` 73 | 74 | - To train **SUREL** for link prediction on Collab: 75 | ```bash 76 | python main.py --dataset ogbl-collab --metric hit --num_step 4 --num_walk 200 --use_val 77 | ``` 78 | 79 | - To train **SUREL** for link prediction on Citation2: 80 | ```bash 81 | python main.py --dataset ogbl-citation2 --metric mrr --num_step 4 --num_walk 100 82 | ``` 83 | 84 | - To train **SUREL** for relation prediction on MAG(A-P): 85 | ```bash 86 | python main_hetro.py --dataset mag --relation write --metric mrr --num_step 3 --num_walk 100 --k 10 87 | ``` 88 | 89 | - To train **SUREL** for higher-order prediction on DBLP: 90 | ```bash 91 | python main_horder.py --dataset DBLP-coauthor --metric mrr --num_step 3 --num_walk 100 92 | ``` 93 | 94 | - All detailed training logs can be found at `<log_dir>/<dataset>/<stamp>.log`. 95 | 96 | ## Result Reproduction 97 | This section supplements our SUREL paper accepted in VLDB'22. To reproduce the results of SUREL reported in Tables 3 and 4, use the following commands: 98 | * OGBL - Link Prediction 99 | ```bash 100 | python3 main.py --dataset <dataset> --metric <metric> --num_step <num_step> --num_walk <num_walk> --k <k> 101 | ``` 102 | where `dataset` can be one of `ogbl-citation2`, `ogbl-collab`, or `ogbl-ppa`; `metric` can be either `mrr` or `hit`. 103 | * Relation Type Prediction 104 | ```bash 105 | python main_hetro.py --dataset mag --relation <relation> --metric mrr --num_step <num_step> --num_walk <num_walk> --k <k> 106 | ``` 107 | where `relation` can be either `write` or `cite`. 108 | * Higher-order Pattern Prediction 109 | ```bash 110 | python main_horder.py --dataset <dataset> --metric mrr --num_step <num_step> --num_walk <num_walk> --k <k> 111 | ``` 112 | where `dataset` can be either `DBLP-coauthor` or `tags-math`. 113 | 114 | The detailed parameter configurations are provided in Table 8, Appendix D of the [arXiv version](https://arxiv.org/abs/2202.13538) of this work. For the profiling of SUREL in Table 4 and Fig.
4 (a-b), please use the parameter settings provided in Appendix D.3. 115 | 116 | To test the scaling performance of Walk Sampler and RPE Joining, the functions `run_walk` and `sjoin` can be imported and called from the module `surel_gacc`. Adjust the values of `num_walk`, `num_step`, and `nthread` as shown in Fig. 4 (c-d). 117 | 118 | To perform hyper-parameter analysis of the number of walks 𝑀, the number of steps per walk 𝑚, and the hidden dimension 𝑑, adjust the values of `num_walk`, `num_step`, and `hidden_dim` as shown in Fig. 5. 119 | 120 |
121 | Sample Output 122 | 123 | ```text 124 | 2022-03-25 15:57:16,677 - root - INFO - Create log file at ./log/ogbl-citation2/032522_155716.log 125 | 2022-03-25 15:57:16,677 - root - INFO - Command line executed: python main.py --gpu 2 --patience 5 --hidden_dim 64 --seed 0 126 | 2022-03-25 15:57:16,677 - root - INFO - Full args parsed: 127 | 2022-03-25 15:57:16,677 - root - INFO - Namespace(B_size=1500, batch_num=2000, batch_size=32, data_usage=1.0, dataset='ogbl-citation2', debug=False, directed=False, dropout=0.1, eval_steps=100, gpu_id=2, hidden_dim=64, k=50, l2=0.0, layers=2, load_dict=False, load_model=False, log_dir='./log/', lr=0.001, memo=None, metric='mrr', model='RNN', norm='all', nthread=16, num_step=4, num_walk=100, optim='adam', patience=5, repeat=1, res_dir='./dataset/save', rtest=499, save=False, seed=0, stamp='032522_155716', summary_file='result_summary.log', test_ratio=1.0, train_ratio=0.05, use_degree=False, use_feature=False, use_htype=False, use_val=False, use_weight=False, valid_ratio=0.1, x_dim=0) 128 | 2022-03-25 15:57:16,727 - root - INFO - torch num_threads 16 129 | 2022-03-25 15:57:26,536 - root - INFO - eval metric mrr 130 | task type link prediction 131 | download_name citation-v2 132 | version 1 133 | url http://snap.stanford.edu/ogb/data/linkproppred... 
134 | add_inverse_edge False 135 | has_node_attr True 136 | has_edge_attr False 137 | split time 138 | additional node files node_year 139 | additional edge files None 140 | is hetero False 141 | binary False 142 | Name: ogbl-citation2, dtype: object 143 | Keys: ['x', 'edge_index', 'node_year'] 144 | 2022-03-25 15:57:26,536 - root - INFO - node size 2927963, feature dim 128, edge size 30387995 with mask ratio 0.05 145 | 2022-03-25 15:57:26,536 - root - INFO - use_weight False, use_coalesce False, use_degree False, use_val False 146 | 2022-03-25 15:57:45,775 - root - INFO - Sparsity of loaded graph 6.727197221716796e-06 147 | 2022-03-25 15:57:45,782 - root - INFO - Observed subgraph with 2918932 nodes and 28836021 edges; 148 | 2022-03-25 15:57:45,789 - root - INFO - Training subgraph with 1394162 nodes and 1519315 edges. 149 | 2022-03-25 15:57:50,400 - root - INFO - #Model Params 79617 150 | 2022-03-25 15:59:14,643 - root - INFO - Samples: valid 8659 by 1000 test 86596 by 1000 metric: mrr 151 | 2022-03-25 15:59:15,405 - root - INFO - Running Round 1 152 | 2022-03-25 15:59:29,229 - root - INFO - Batch 1 W1502/D1394162 Loss: 0.1971, AUC: 0.5049 153 | 2022-03-25 15:59:42,266 - root - INFO - Batch 2 W2991/D1394162 Loss: 0.1097, AUC: 0.4975 154 | 2022-03-25 15:59:56,187 - root - INFO - Batch 3 W4431/D1394162 Loss: 0.1024, AUC: 0.4976 155 | 2022-03-25 16:00:09,070 - root - INFO - Batch 4 W5761/D1394162 Loss: 0.1030, AUC: 0.4980 156 | 2022-03-25 16:00:23,285 - root - INFO - Batch 5 W7215/D1394162 Loss: 0.1013, AUC: 0.5053 157 | ... 158 | ``` 159 |
160 | 161 | ## Usage 162 | ``` 163 | usage: Interface for SUREL framework [-h] 164 | [--dataset {ogbl-ppa,ogbl-citation2,ogbl-collab,mag,DBLP-coauthor,tags-math}] 165 | [--model {RNN,MLP,Transformer,GNN}] 166 | [--layers LAYERS] 167 | [--hidden_dim HIDDEN_DIM] [--x_dim X_DIM] 168 | [--data_usage DATA_USAGE] 169 | [--train_ratio TRAIN_RATIO] 170 | [--valid_ratio VALID_RATIO] 171 | [--test_ratio TEST_RATIO] 172 | [--metric {auc,mrr,hit}] [--seed SEED] 173 | [--gpu_id GPU_ID] [--nthread NTHREAD] 174 | [--B_size B_SIZE] [--num_walk NUM_WALK] 175 | [--num_step NUM_STEP] [--k K] 176 | [--directed DIRECTED] [--use_feature] 177 | [--use_weight] [--use_degree] 178 | [--use_htype] [--use_val] [--norm NORM] 179 | [--optim OPTIM] [--rtest RTEST] 180 | [--eval_steps EVAL_STEPS] 181 | [--batch_size BATCH_SIZE] 182 | [--batch_num BATCH_NUM] [--lr LR] 183 | [--dropout DROPOUT] [--l2 L2] 184 | [--patience PATIENCE] [--repeat REPEAT] 185 | [--log_dir LOG_DIR] [--res_dir RES_DIR] 186 | [--stamp STAMP] 187 | [--summary_file SUMMARY_FILE] [--debug] 188 | [--abs] [--save] [--load_dict] 189 | [--load_model] [--memo MEMO] 190 | ``` 191 | 192 |
193 | Optional Arguments 194 | 195 | ``` 196 | optional arguments: 197 | -h, --help show this help message and exit 198 | --dataset {mag} dataset name 199 | --relation {write,cite} 200 | relation type 201 | --model {RNN,MLP,Transformer,GNN} 202 | base model to use 203 | --layers LAYERS number of layers 204 | --hidden_dim HIDDEN_DIM 205 | hidden dimension 206 | --x_dim X_DIM dim of raw node features 207 | --data_usage DATA_USAGE 208 | use partial dataset 209 | --train_ratio TRAIN_RATIO 210 | mask partial edges for training 211 | --valid_ratio VALID_RATIO 212 | use partial valid set 213 | --test_ratio TEST_RATIO 214 | use partial test set 215 | --metric {auc,mrr,hit} 216 | metric for evaluating performance 217 | --seed SEED seed to initialize all the random modules 218 | --gpu_id GPU_ID gpu id 219 | --nthread NTHREAD number of thread 220 | --B_size B_SIZE set size of train sampling 221 | --num_walk NUM_WALK total number of random walks 222 | --num_step NUM_STEP total steps of random walk 223 | --k K number of paired negative queries 224 | --directed DIRECTED whether to treat the graph as directed 225 | --use_feature whether to use raw features as input 226 | --use_weight whether to use edge weight as input 227 | --use_degree whether to use node degree as input 228 | --use_htype whether to use node type as input 229 | --use_val whether to use val as input 230 | --norm NORM method of normalization 231 | --optim OPTIM optimizer to use 232 | --rtest RTEST step start to test 233 | --eval_steps EVAL_STEPS 234 | number of steps to test 235 | --batch_size BATCH_SIZE 236 | mini-batch size (train) 237 | --batch_num BATCH_NUM 238 | mini-batch size (test) 239 | --lr LR learning rate 240 | --dropout DROPOUT dropout rate 241 | --l2 L2 l2 regularization (weight decay) 242 | --patience PATIENCE early stopping steps 243 | --repeat REPEAT number of training instances to repeat 244 | --log_dir LOG_DIR log directory 245 | --res_dir RES_DIR resource directory 246 | --stamp STAMP time 
stamp 247 | --summary_file SUMMARY_FILE 248 | brief summary of training results 249 | --debug whether to use debug mode 250 | --save whether to save RPE to files 251 | --load_dict whether to load RPE from files 252 | --load_model whether to load saved model from files 253 | --memo MEMO notes 254 | ``` 255 |
256 | 257 | ## Citation 258 | Please cite our paper if you are interested in our work. 259 | ``` 260 | @article{yin2022algorithm, 261 | title={Algorithm and System Co-design for Efficient Subgraph-based Graph Representation Learning}, 262 | author={Yin, Haoteng and Zhang, Muhan and Wang, Yanbang and Wang, Jianguo and Li, Pan}, 263 | journal={Proceedings of the VLDB Endowment}, 264 | volume={15}, 265 | number={11}, 266 | pages={2788-2796}, 267 | year={2022} 268 | } 269 | ``` 270 | -------------------------------------------------------------------------------- /dataloaders.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import time 5 | import numpy as np 6 | 7 | import torch 8 | from ogb.linkproppred import PygLinkPropPredDataset 9 | from scipy.sparse import csr_matrix 10 | from torch_sparse import coalesce 11 | from tqdm import tqdm 12 | 13 | from utils import get_pos_neg_edges, np_sampling 14 | 15 | 16 | class DEDataset(): 17 | def __init__(self, dataset, mask_ratio=0.05, use_weight=False, use_coalesce=False, use_degree=False, 18 | use_val=False): 19 | self.data = PygLinkPropPredDataset(name=dataset) 20 | self.graph = self.data[0] 21 | self.split_edge = self.data.get_edge_split() 22 | self.mask_ratio = mask_ratio 23 | self.use_degree = use_degree 24 | self.use_weight = (use_weight and 'edge_weight' in self.graph) 25 | self.use_coalesce = use_coalesce 26 | self.use_val = use_val 27 | self.gtype = 'Homogeneous' 28 | 29 | if 'x' in self.graph: 30 | self.num_nodes, self.num_feature = self.graph['x'].shape 31 | else: 32 | self.num_nodes, self.num_feature = len(torch.unique(self.graph['edge_index'])), None 33 | 34 | if 'source_node' in self.split_edge['train']: 35 | self.directed = True 36 | self.train_edge = self.graph['edge_index'].t() 37 | else: 38 | self.directed = False 39 | self.train_edge = self.split_edge['train']['edge'] 40 | 41 | if use_weight: 42 | 
self.train_weight = self.split_edge['train']['weight'] 43 | if use_coalesce: 44 | train_edge_col, self.train_weight = coalesce(self.train_edge.t(), self.train_weight, self.num_nodes, 45 | self.num_nodes) 46 | self.train_edge = train_edge_col.t() 47 | self.train_wmax = max(self.train_weight) 48 | else: 49 | self.train_weight = None 50 | # must put after coalesce 51 | self.len_train = self.train_edge.shape[0] 52 | 53 | def process(self, logger): 54 | logger.info(f'{self.data.meta_info}\nKeys: {self.graph.keys}') 55 | logger.info( 56 | f'node size {self.num_nodes}, feature dim {self.num_feature}, edge size {self.len_train} with mask ratio {self.mask_ratio}') 57 | logger.info( 58 | f'use_weight {self.use_weight}, use_coalesce {self.use_coalesce}, use_degree {self.use_degree}, use_val {self.use_val}') 59 | 60 | self.num_pos = int(self.len_train * self.mask_ratio) 61 | idx = np.random.permutation(self.len_train) 62 | # pos sample edges masked for training, observed edges for structural features 63 | self.pos_edge, obsrv_edge = self.train_edge[idx[:self.num_pos]], self.train_edge[idx[self.num_pos:]] 64 | val_edge = self.train_edge 65 | self.val_nodes = torch.unique(self.train_edge).tolist() 66 | 67 | if self.use_weight: 68 | pos_e_weight = self.train_weight[idx[:self.num_pos]] 69 | obsrv_e_weight = self.train_weight[idx[self.num_pos:]] 70 | val_e_weight = self.train_weight 71 | else: 72 | pos_e_weight = np.ones(self.num_pos, dtype=int) 73 | obsrv_e_weight = np.ones(self.len_train - self.num_pos, dtype=int) 74 | val_e_weight = np.ones(self.len_train, dtype=int) 75 | 76 | if self.use_val: 77 | # collab allows using valid edges for training 78 | obsrv_edge = torch.cat([obsrv_edge, self.split_edge['valid']['edge']]) 79 | full_edge = torch.cat([self.train_edge, self.split_edge['valid']['edge']], dim=0) 80 | self.test_nodes = torch.unique(full_edge).tolist() 81 | if self.use_weight: 82 | obsrv_e_weight = torch.cat([self.train_weight[idx[self.num_pos:]], 
self.split_edge['valid']['weight']]) 83 | full_e_weight = torch.cat([self.train_weight, self.split_edge['valid']['weight']], dim=0) 84 | if self.use_coalesce: 85 | obsrv_edge_col, obsrv_e_weight = coalesce(obsrv_edge.t(), obsrv_e_weight, self.num_nodes, 86 | self.num_nodes) 87 | obsrv_edge = obsrv_edge_col.t() 88 | full_edge_col, full_e_weight = coalesce(full_edge.t(), full_e_weight, self.num_nodes, 89 | self.num_nodes) 90 | full_edge = full_edge_col.t() 91 | self.full_wmax = max(full_e_weight) 92 | else: 93 | obsrv_e_weight = np.ones(obsrv_edge.shape[0], dtype=int) 94 | full_e_weight = np.ones(full_edge.shape[0], dtype=int) 95 | else: 96 | full_edge, full_e_weight = self.train_edge, self.train_weight 97 | self.test_nodes = self.val_nodes 98 | 99 | # load observed graph and save as a CSR sparse matrix 100 | max_obsrv_idx = torch.max(obsrv_edge).item() 101 | net_obsrv = csr_matrix((obsrv_e_weight, (obsrv_edge[:, 0].numpy(), obsrv_edge[:, 1].numpy())), 102 | shape=(max_obsrv_idx + 1, max_obsrv_idx + 1)) 103 | G_obsrv = net_obsrv + net_obsrv.T 104 | assert sum(G_obsrv.diagonal()) == 0 105 | 106 | # subgraph for training(5 % edges, pos edges) 107 | max_pos_idx = torch.max(self.pos_edge).item() 108 | net_pos = csr_matrix((pos_e_weight, (self.pos_edge[:, 0].numpy(), self.pos_edge[:, 1].numpy())), 109 | shape=(max_pos_idx + 1, max_pos_idx + 1)) 110 | G_pos = net_pos + net_pos.T 111 | assert sum(G_pos.diagonal()) == 0 112 | 113 | max_val_idx = torch.max(val_edge).item() 114 | net_val = csr_matrix((val_e_weight, (val_edge[:, 0].numpy(), val_edge[:, 1].numpy())), 115 | shape=(max_val_idx + 1, max_val_idx + 1)) 116 | G_val = net_val + net_val.T 117 | assert sum(G_val.diagonal()) == 0 118 | 119 | if self.use_val: 120 | max_full_idx = torch.max(full_edge).item() 121 | net_full = csr_matrix((full_e_weight, (full_edge[:, 0].numpy(), full_edge[:, 1].numpy())), 122 | shape=(max_full_idx + 1, max_full_idx + 1)) 123 | G_full = net_full + net_full.transpose() 124 | assert 
sum(G_full.diagonal()) == 0 125 | else: 126 | G_full = G_val 127 | 128 | self.degree = np.expand_dims(np.log(G_full.getnnz(axis=1) + 1), 1).astype( 129 | np.float32) if self.use_degree else None 130 | 131 | # sparsity of graph 132 | logger.info(f'Sparsity of loaded graph {G_obsrv.getnnz() / (max_obsrv_idx + 1) ** 2}') 133 | # statistic of graph 134 | logger.info( 135 | f'Observed subgraph with {np.sum(G_obsrv.getnnz(axis=1) > 0)} nodes and {int(G_obsrv.nnz / 2)} edges;') 136 | logger.info(f'Training subgraph with {np.sum(G_pos.getnnz(axis=1) > 0)} nodes and {int(G_pos.nnz / 2)} edges.') 137 | 138 | self.data, self.graph = None, None 139 | 140 | return {'pos': G_pos, 'train': G_obsrv, 'val': G_val, 'test': G_full} 141 | 142 | 143 | class DE_Hetro_Dataset(): 144 | def __init__(self, dataset, relation, mask_ratio=0.05): 145 | self.data = torch.load(f'./dataset/{dataset}_{relation}.pl') 146 | self.split_edge = self.data['split_edge'] 147 | self.node_type = list(self.data['num_nodes_dict']) 148 | self.mask_ratio = mask_ratio 149 | rel_key = ('author', 'writes', 'paper') if relation == 'cite' else ('paper', 'cites', 'paper') 150 | self.obsrv_edge = self.data['edge_index'][rel_key] 151 | self.split_edge = self.data['split_edge'] 152 | self.gtype = 'Heterogeneous' if relation == 'cite' else 'Homogeneous' 153 | 154 | if 'x' in self.data: 155 | self.num_nodes, self.num_feature = self.data['x'].shape 156 | else: 157 | self.num_nodes, self.num_feature = self.obsrv_edge.unique().size(0), None 158 | 159 | if 'source_node' in self.split_edge['train']: 160 | self.directed = True 161 | self.train_edge = self.graph['edge_index'].t() 162 | else: 163 | self.directed = False 164 | self.train_edge = self.split_edge['train']['edge'] 165 | 166 | self.len_train = self.train_edge.shape[0] 167 | 168 | def process(self, logger): 169 | logger.info( 170 | f'node size {self.num_nodes}, feature dim {self.num_feature}, edge size {self.len_train} with mask ratio {self.mask_ratio}') 171 | 172 | 
self.num_pos = int(self.len_train * self.mask_ratio) 173 | idx = np.random.permutation(self.len_train) 174 | # pos sample edges masked for training, observed edges for structural features 175 | self.pos_edge, obsrv_edge = self.train_edge[idx[:self.num_pos]], torch.cat( 176 | [self.train_edge[idx[self.num_pos:]], self.obsrv_edge]) 177 | val_edge = torch.cat([self.train_edge, self.obsrv_edge]) 178 | len_redge = len(self.obsrv_edge) 179 | 180 | pos_e_weight = np.ones(self.num_pos, dtype=int) 181 | obsrv_e_weight = np.ones(self.len_train - self.num_pos + len_redge, dtype=int) 182 | val_e_weight = np.ones(self.len_train + len_redge, dtype=int) 183 | 184 | # load observed graph and save as a CSR sparse matrix 185 | max_obsrv_idx = torch.max(obsrv_edge).item() 186 | net_obsrv = csr_matrix((obsrv_e_weight, (obsrv_edge[:, 0].numpy(), obsrv_edge[:, 1].numpy())), 187 | shape=(max_obsrv_idx + 1, max_obsrv_idx + 1)) 188 | G_obsrv = net_obsrv + net_obsrv.T 189 | assert sum(G_obsrv.diagonal()) == 0 190 | 191 | # subgraph for training(5 % edges, pos edges) 192 | max_pos_idx = torch.max(self.pos_edge).item() 193 | net_pos = csr_matrix((pos_e_weight, (self.pos_edge[:, 0].numpy(), self.pos_edge[:, 1].numpy())), 194 | shape=(max_pos_idx + 1, max_pos_idx + 1)) 195 | G_pos = net_pos + net_pos.T 196 | assert sum(G_pos.diagonal()) == 0 197 | 198 | max_val_idx = torch.max(val_edge).item() 199 | net_val = csr_matrix((val_e_weight, (val_edge[:, 0].numpy(), val_edge[:, 1].numpy())), 200 | shape=(max_val_idx + 1, max_val_idx + 1)) 201 | G_val = net_val + net_val.T 202 | assert sum(G_val.diagonal()) == 0 203 | 204 | G_full = G_val 205 | # sparsity of graph 206 | logger.info(f'Sparsity of loaded graph {G_obsrv.getnnz() / (max_obsrv_idx + 1) ** 2}') 207 | # statistic of graph 208 | logger.info( 209 | f'Observed subgraph with {np.sum(G_obsrv.getnnz(axis=1) > 0)} nodes and {int(G_obsrv.nnz / 2)} edges;') 210 | logger.info(f'Training subgraph with {np.sum(G_pos.getnnz(axis=1) > 0)} nodes and 
{int(G_pos.nnz / 2)} edges.') 211 | 212 | self.data = None 213 | return {'pos': G_pos, 'train': G_obsrv, 'val': G_val, 'test': G_full} 214 | 215 | 216 | class DE_Hyper_Dataset(): 217 | def __init__(self, dataset, mask_ratio=0.6): 218 | self.data = torch.load(f'./dataset/{dataset}.pl') 219 | self.obsrv_edge = torch.from_numpy(self.data['edge_index']) 220 | self.num_tup = len(self.data['triplets']) 221 | self.mask_ratio = mask_ratio 222 | self.split_edge = self.data['triplets'] 223 | self.gtype = 'Hypergraph' 224 | 225 | if 'x' in self.data: 226 | self.num_nodes, self.num_feature = self.data['x'].shape 227 | else: 228 | self.num_nodes, self.num_feature = self.obsrv_edge.unique().size(0), None 229 | 230 | def get_edge_split(self, ratio, k=1000, seed=2021): 231 | np.random.seed(seed) 232 | tuples = torch.from_numpy(self.data['triplets']) 233 | idx = np.random.permutation(self.num_tup) 234 | num_train = int(ratio * self.num_tup) 235 | split_idx = {'train': {'hedge': tuples[idx[:num_train]]}} 236 | val_idx, test_idx = np.split(idx[num_train:], 2) 237 | split_idx['valid'], split_idx['test'] = {'hedge': tuples[val_idx]}, {'hedge': tuples[test_idx]} 238 | node_neg = torch.randint(torch.max(tuples), (len(val_idx), k)) 239 | split_idx['valid']['hedge_neg'] = torch.cat( 240 | [split_idx['valid']['hedge'][:, :2].repeat(1, k).view(-1, 2).t(), node_neg.view(1, -1)]).t() 241 | split_idx['test']['hedge_neg'] = torch.cat( 242 | [split_idx['test']['hedge'][:, :2].repeat(1, k).view(-1, 2).t(), node_neg.view(1, -1)]).t() 243 | return split_idx 244 | 245 | def process(self, logger): 246 | logger.info( 247 | f'node size {self.num_nodes}, feature dim {self.num_feature}, edge size {self.num_tup} with mask ratio {self.mask_ratio}') 248 | obsrv_edge = self.obsrv_edge 249 | 250 | # load observed graph and save as a CSR sparse matrix 251 | max_obsrv_idx = torch.max(obsrv_edge).item() 252 | obsrv_e_weight = np.ones(len(obsrv_edge), dtype=int) 253 | net_obsrv = csr_matrix((obsrv_e_weight, 
(obsrv_edge[:, 0].numpy(), obsrv_edge[:, 1].numpy())), 254 | shape=(max_obsrv_idx + 1, max_obsrv_idx + 1)) 255 | G_enc = net_obsrv + net_obsrv.T 256 | assert sum(G_enc.diagonal()) == 0 257 | 258 | # sparsity of graph 259 | logger.info(f'Sparsity of loaded graph {G_enc.getnnz() / (max_obsrv_idx + 1) ** 2}') 260 | # statistic of graph 261 | logger.info(f'Observed subgraph with {np.sum(G_enc.getnnz(axis=1) > 0)} nodes and {int(G_enc.nnz / 2)} edges;') 262 | 263 | return G_enc 264 | 265 | 266 | def gen_dataset(dataset, graphs, args, bsize=10000): 267 | G_val, G_full = graphs['val'], graphs['test'] 268 | 269 | keep_neg = False if 'ppa' not in args.dataset else True 270 | 271 | test_pos_edge, test_neg_edge = get_pos_neg_edges('test', dataset.split_edge, ratio=args.test_ratio, 272 | keep_neg=keep_neg) 273 | val_pos_edge, val_neg_edge = get_pos_neg_edges('valid', dataset.split_edge, ratio=args.valid_ratio, 274 | keep_neg=keep_neg) 275 | 276 | inf_set = {'test': {}, 'val': {}} 277 | 278 | if args.metric == 'mrr': 279 | inf_set['test']['E'] = torch.cat([test_pos_edge, test_neg_edge], dim=1).t() 280 | inf_set['val']['E'] = torch.cat([val_pos_edge, val_neg_edge], dim=1).t() 281 | inf_set['test']['num_pos'], inf_set['val']['num_pos'] = test_pos_edge.shape[1], val_pos_edge.shape[1] 282 | inf_set['test']['num_neg'], inf_set['val']['num_neg'] = test_neg_edge.shape[1] // inf_set['test']['num_pos'], \ 283 | val_neg_edge.shape[1] // inf_set['val']['num_pos'] 284 | elif 'Hit' in args.metric: 285 | inf_set['test']['E'] = torch.cat([test_neg_edge, test_pos_edge], dim=1).t() 286 | inf_set['val']['E'] = torch.cat([val_neg_edge, val_pos_edge], dim=1).t() 287 | inf_set['test']['num_pos'], inf_set['val']['num_pos'] = test_pos_edge.shape[1], val_pos_edge.shape[1] 288 | inf_set['test']['num_neg'], inf_set['val']['num_neg'] = test_neg_edge.shape[1], val_neg_edge.shape[1] 289 | else: 290 | raise NotImplementedError 291 | 292 | if args.use_val: 293 | val_dict = np_sampling({}, G_val.indptr, 
G_val.indices, bsize=bsize, 294 | target=torch.unique(inf_set['val']['E']).tolist(), num_walks=args.num_walk, 295 | num_steps=args.num_step - 1) 296 | test_dict = np_sampling({}, G_full.indptr, G_full.indices, bsize=bsize, 297 | target=torch.unique(inf_set['test']['E']).tolist(), num_walks=args.num_walk, 298 | num_steps=args.num_step - 1) 299 | else: 300 | val_dict = test_dict = np_sampling({}, G_val.indptr, G_val.indices, bsize=bsize, 301 | target=torch.unique( 302 | torch.cat([inf_set['val']['E'], inf_set['test']['E']])).tolist(), 303 | num_walks=args.num_walk, num_steps=args.num_step - 1) 304 | 305 | if not args.use_feature: 306 | if args.use_degree: 307 | inf_set['X'] = torch.from_numpy(dataset.degree) 308 | elif args.use_htype: 309 | inf_set['X'] = dataset.node_map 310 | else: 311 | inf_set['X'] = None 312 | else: 313 | inf_set['X'] = dataset.graph['x'] 314 | args.x_dim = inf_set['X'].shape[-1] 315 | 316 | args.w_max = dataset.train_wmax if args.use_weight else None 317 | 318 | return test_dict, val_dict, inf_set 319 | 320 | 321 | def gen_dataset_hyper(dataset, G_enc, args, bsize=10000): 322 | test_pos_edge, test_neg_edge = get_pos_neg_edges('test', dataset.split_edge, ratio=args.test_ratio) 323 | val_pos_edge, val_neg_edge = get_pos_neg_edges('valid', dataset.split_edge, ratio=args.valid_ratio) 324 | 325 | inf_set = {'test': {}, 'val': {}} 326 | 327 | if args.metric == 'mrr': 328 | inf_set['test']['E'] = torch.cat([test_pos_edge, test_neg_edge]) 329 | inf_set['val']['E'] = torch.cat([val_pos_edge, val_neg_edge]) 330 | inf_set['test']['num_pos'], inf_set['val']['num_pos'] = test_pos_edge.shape[0], val_pos_edge.shape[0] 331 | inf_set['test']['num_neg'], inf_set['val']['num_neg'] = test_neg_edge.shape[0] // inf_set['test']['num_pos'], \ 332 | val_neg_edge.shape[0] // inf_set['val']['num_pos'] 333 | else: 334 | raise NotImplementedError 335 | 336 | inf_dict = np_sampling({}, G_enc.indptr, G_enc.indices, 337 | bsize=bsize, 338 | 
target=torch.unique(torch.cat([inf_set['val']['E'], inf_set['test']['E']])).tolist(), 339 | num_walks=args.num_walk, 340 | num_steps=args.num_step - 1) 341 | 342 | if not args.use_feature: 343 | inf_set['X'] = None 344 | else: 345 | inf_set['X'] = dataset.graph['x'] 346 | args.x_dim = inf_set['X'].shape[-1] 347 | 348 | return inf_dict, inf_set 349 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sgrl_env 2 | channels: 3 | - pyg 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - absl-py=1.3.0=py310h06a4308_0 10 | - aiohttp=3.8.3=py310h5eee18b_0 11 | - aiosignal=1.2.0=pyhd3eb1b0_0 12 | - async-timeout=4.0.2=py310h06a4308_0 13 | - attrs=22.1.0=py310h06a4308_0 14 | - blas=1.0=mkl 15 | - blinker=1.4=py310h06a4308_0 16 | - brotlipy=0.7.0=py310h7f8727e_1002 17 | - bzip2=1.0.8=h7b6447c_0 18 | - c-ares=1.18.1=h7f8727e_0 19 | - ca-certificates=2023.01.10=h06a4308_0 20 | - cachetools=4.2.2=pyhd3eb1b0_0 21 | - certifi=2022.12.7=py310h06a4308_0 22 | - cffi=1.15.1=py310h5eee18b_3 23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 24 | - click=8.0.4=py310h06a4308_0 25 | - cryptography=38.0.4=py310h9ce1e76_0 26 | - cudatoolkit=11.3.1=h2bc3f7f_2 27 | - ffmpeg=4.3=hf484d3e_0 28 | - fftw=3.3.9=h27cfd23_1 29 | - flit-core=3.6.0=pyhd3eb1b0_0 30 | - freetype=2.12.1=h4a9f257_0 31 | - frozenlist=1.3.3=py310h5eee18b_0 32 | - giflib=5.2.1=h7b6447c_0 33 | - gmp=6.2.1=h295c915_3 34 | - gnutls=3.6.15=he1e5248_0 35 | - google-auth=2.6.0=pyhd3eb1b0_0 36 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 37 | - grpcio=1.42.0=py310hce63b2e_0 38 | - idna=3.4=py310h06a4308_0 39 | - intel-openmp=2021.4.0=h06a4308_3561 40 | - jinja2=3.1.2=py310h06a4308_0 41 | - joblib=1.1.1=py310h06a4308_0 42 | - jpeg=9e=h7f8727e_0 43 | - lame=3.100=h7b6447c_0 44 | - lcms2=2.12=h3be6417_0 45 | - ld_impl_linux-64=2.38=h1181459_1 46 
| - lerc=3.0=h295c915_0 47 | - libdeflate=1.8=h7f8727e_5 48 | - libffi=3.4.2=h6a678d5_6 49 | - libgcc-ng=11.2.0=h1234567_1 50 | - libgfortran-ng=11.2.0=h00389a5_1 51 | - libgfortran5=11.2.0=h1234567_1 52 | - libgomp=11.2.0=h1234567_1 53 | - libiconv=1.16=h7f8727e_2 54 | - libidn2=2.3.2=h7f8727e_0 55 | - libpng=1.6.37=hbc83047_0 56 | - libprotobuf=3.20.1=h4ff587b_0 57 | - libstdcxx-ng=11.2.0=h1234567_1 58 | - libtasn1=4.16.0=h27cfd23_0 59 | - libtiff=4.5.0=hecacb30_0 60 | - libunistring=0.9.10=h27cfd23_0 61 | - libuuid=1.41.5=h5eee18b_0 62 | - libwebp=1.2.4=h11a3e52_0 63 | - libwebp-base=1.2.4=h5eee18b_0 64 | - lz4-c=1.9.4=h6a678d5_0 65 | - markdown=3.4.1=py310h06a4308_0 66 | - markupsafe=2.1.1=py310h7f8727e_0 67 | - mkl=2021.4.0=h06a4308_640 68 | - mkl-service=2.4.0=py310h7f8727e_0 69 | - mkl_fft=1.3.1=py310hd6ae3a3_0 70 | - mkl_random=1.2.2=py310h00e6091_0 71 | - multidict=6.0.2=py310h5eee18b_0 72 | - ncurses=6.3=h5eee18b_3 73 | - nettle=3.7.3=hbbd107a_1 74 | - numpy=1.23.5=py310hd5efca6_0 75 | - numpy-base=1.23.5=py310h8e6c178_0 76 | - oauthlib=3.2.1=py310h06a4308_0 77 | - openh264=2.1.1=h4ff587b_0 78 | - openssl=1.1.1s=h7f8727e_0 79 | - pillow=9.3.0=py310hace64e9_1 80 | - pip=22.3.1=py310h06a4308_0 81 | - protobuf=3.20.1=py310h295c915_0 82 | - psutil=5.9.0=py310h5eee18b_0 83 | - pyasn1=0.4.8=pyhd3eb1b0_0 84 | - pyasn1-modules=0.2.8=py_0 85 | - pycparser=2.21=pyhd3eb1b0_0 86 | - pyg=2.2.0=py310_torch_1.12.0_cu113 87 | - pyjwt=2.4.0=py310h06a4308_0 88 | - pyopenssl=22.0.0=pyhd3eb1b0_0 89 | - pyparsing=3.0.9=py310h06a4308_0 90 | - pysocks=1.7.1=py310h06a4308_0 91 | - python=3.10.9=h7a1cb2a_0 92 | - pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0 93 | - pytorch-cluster=1.6.0=py310_torch_1.12.0_cu113 94 | - pytorch-mutex=1.0=cuda 95 | - pytorch-scatter=2.1.0=py310_torch_1.12.0_cu113 96 | - pytorch-sparse=0.6.16=py310_torch_1.12.0_cu113 97 | - readline=8.2=h5eee18b_0 98 | - requests=2.28.1=py310h06a4308_0 99 | - requests-oauthlib=1.3.0=py_0 100 | - rsa=4.7.2=pyhd3eb1b0_1 
101 | - scikit-learn=1.2.0=py310h6a678d5_0 102 | - scipy=1.9.3=py310hd5efca6_0 103 | - setuptools=65.6.3=py310h06a4308_0 104 | - six=1.16.0=pyhd3eb1b0_1 105 | - sqlite=3.40.1=h5082296_0 106 | - tensorboard=2.10.0=py310h06a4308_0 107 | - tensorboard-data-server=0.6.1=py310h52d8a92_0 108 | - tensorboard-plugin-wit=1.8.1=py310h06a4308_0 109 | - threadpoolctl=2.2.0=pyh0d69192_0 110 | - tk=8.6.12=h1ccaba5_0 111 | - torchaudio=0.12.1=py310_cu113 112 | - torchvision=0.13.1=py310_cu113 113 | - tqdm=4.64.1=py310h06a4308_0 114 | - typing_extensions=4.4.0=py310h06a4308_0 115 | - tzdata=2022g=h04d1e81_0 116 | - urllib3=1.26.14=py310h06a4308_0 117 | - werkzeug=2.2.2=py310h06a4308_0 118 | - wheel=0.37.1=pyhd3eb1b0_0 119 | - xz=5.2.10=h5eee18b_1 120 | - yarl=1.8.1=py310h5eee18b_0 121 | - zlib=1.2.13=h5eee18b_0 122 | - zstd=1.5.2=ha4553b6_0 123 | - pip: 124 | - fastremap==1.13.3 125 | - littleutils==0.2.2 126 | - llvmlite==0.39.1 127 | - numba==0.56.4 128 | - ogb==1.3.5 129 | - outdated==0.2.2 130 | - pandas==1.5.3 131 | - pyg-lib==0.1.0+pt112cu113 132 | - python-dateutil==2.8.2 133 | - pytz==2022.7.1 134 | - streamtologger==2017.1 135 | -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import logging 5 | import os 6 | import socket 7 | import time 8 | 9 | import numpy as np 10 | import streamtologger 11 | 12 | 13 | def set_up_log(args, sys_argv): 14 | log_dir = args.log_dir 15 | save_dir = os.path.join(args.res_dir, 'model', args.dataset) 16 | dataset_log_dir = os.path.join(log_dir, args.dataset) 17 | if not os.path.exists(save_dir): 18 | os.makedirs(save_dir) 19 | if not os.path.exists(log_dir): 20 | os.mkdir(log_dir) 21 | if not os.path.exists(dataset_log_dir): 22 | os.mkdir(dataset_log_dir) 23 | 24 | args.stamp = time.strftime('%m%d%y_%H%M%S') 25 | file_path = os.path.join(dataset_log_dir, 
f"{args.stamp}.log") 26 | 27 | logging.basicConfig(level=logging.INFO) 28 | logger = logging.getLogger() 29 | logger.setLevel(logging.DEBUG) 30 | fh = logging.FileHandler(file_path) 31 | fh.setLevel(logging.DEBUG) 32 | ch = logging.StreamHandler() 33 | ch.setLevel(logging.WARN) 34 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 35 | fh.setFormatter(formatter) 36 | ch.setFormatter(formatter) 37 | logger.addHandler(fh) 38 | logger.addHandler(ch) 39 | logger.info('Create log file at {}'.format(file_path)) 40 | logger.info('Command line executed: python ' + ' '.join(sys_argv)) 41 | logger.info('Full args parsed:') 42 | logger.info(args) 43 | if args.debug: 44 | streamtologger.redirect(target=logger) 45 | return logger 46 | 47 | 48 | def save_performance_result(args, logger, metrics, repeat=0): 49 | summary_file = args.summary_file 50 | if summary_file != 'test': 51 | summary_file = os.path.join(args.log_dir, summary_file) 52 | else: 53 | return 54 | dataset = args.dataset 55 | val_metric, no_val_metric = metrics 56 | model_name = '-'.join([args.model, str(args.num_step), str(args.num_walk), str(args.K)]) 57 | seed = args.seed 58 | log_name = os.path.split(logger.handlers[1].baseFilename)[-1] 59 | server = socket.gethostname() 60 | line = '\t'.join( 61 | [dataset, model_name, str(seed), str(round(val_metric, 4)), f'R{repeat}', str(round(no_val_metric, 4)), 62 | log_name, server]) + '\n' 63 | try: 64 | with open(summary_file, 'a') as f: 65 | f.write(line) 66 | except: 67 | raise Warning(f'Unable to write back summary file at {summary_file}.') 68 | 69 | 70 | def save_to_file(dic, args, logger, dtype): 71 | save_dict = dic.copy() 72 | flag = 'W' if args.use_weight else 'R' 73 | 74 | if args.save: 75 | if args.use_val and dtype == 'test': 76 | file_name = f'{args.res_dir}/dict/{args.dataset}_{dtype}_{args.num_step}_{args.num_walk}_{flag}_uval.pt' 77 | else: 78 | file_name = 
f'{args.res_dir}/dict/{args.dataset}_{dtype}_{args.num_step}_{args.num_walk}_{flag}_wo.pt' 79 | if not os.path.exists(file_name): 80 | save_dict.pop('num') 81 | keys, values = list(save_dict.keys()), list(save_dict.values()) 82 | walks, ids, freqs = zip(*values) 83 | np.savez(file_name, X=keys, Y=ids, W=walks, F=freqs) 84 | logger.info(f'Saved {dtype} set to {file_name}') 85 | else: 86 | logger.info(f'File exists, {dtype} skipped.') 87 | else: 88 | logger.info(f'Converted {dtype} set to tensor.') 89 | save_dict['flag'] = False 90 | return save_dict 91 | 92 | 93 | def log_record(logger, tb, out, dic, b_time, batchIdx): 94 | mode, metric, auc = out['mode'], out['metric'], out['auc'] 95 | dt = time.time() - b_time 96 | if tb is not None: 97 | tb.add_scalar(f"AUC/{mode}", auc, batchIdx) 98 | key_metric, key_auc = f'{mode}_{metric}', f'{mode}_AUC' 99 | 100 | if metric == 'mrr': 101 | out_metric = out['mrr_list'].mean() 102 | if tb is not None: 103 | tb.add_scalar(f"MRR/{mode}", out_metric, batchIdx) 104 | logger.info(f"AUC/{mode}: {auc:.4f}, MRR {out_metric:.4f} # {len(out['mrr_list'])} Time {dt:.2f}s") 105 | dic[key_metric].append(out_metric.item()) 106 | elif 'Hit' in metric: 107 | if tb is not None: 108 | tb.add_scalars(f"Hits/{mode}", out['hits'], batchIdx) 109 | hits = ' '.join([f'{k}: {v:.4f}' for k, v in out['hits'].items()]) 110 | logger.info(f"AUC/{mode}: {auc:.4f}, {hits} # {out['num_pos']} Time {dt:.2f}s") 111 | dic[key_metric].append(out['hits'][metric]) 112 | else: 113 | raise NotImplementedError 114 | 115 | dic[key_auc].append(auc) 116 | val_metric = f'val_{metric}' 117 | len_val = len(dic[val_metric]) 118 | if mode == 'test': 119 | len_test = len(dic[key_metric]) 120 | if len_val > len_test: 121 | idx = np.argmax(dic[val_metric][-len_test:]) - len_test 122 | else: 123 | idx = np.argmax(dic[val_metric]) 124 | logger.info(f'Best {metric}: val {dic[val_metric][idx]:.4f} test {dic[key_metric][idx]:.4f}') 125 | if idx == (len_test - 1): 126 | return True 127 | 
elif mode == 'val': 128 | if np.argmax(dic[val_metric]) == (len_val - 1): 129 | return True 130 | return False 131 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | 6 | import torch 7 | from surel_gacc import run_sample 8 | from ogb.linkproppred import Evaluator 9 | from torch.utils.tensorboard import SummaryWriter 10 | from torch_geometric.utils import subgraph 11 | 12 | from dataloaders import * 13 | from log import * 14 | from models.model import Net 15 | from train import * 16 | from utils import * 17 | import sys 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser('Interface for SUREL framework') 21 | 22 | # general model and training setting 23 | parser.add_argument('--dataset', type=str, default='ogbl-citation2', help='dataset name', 24 | choices=['ogbl-ppa', 'ogbl-ddi', 'ogbl-citation2', 'ogbl-collab', 'mag']) 25 | parser.add_argument('--model', type=str, default='RNN', help='base model to use', 26 | choices=['RNN', 'MLP', 'Transformer', 'GNN']) 27 | parser.add_argument('--layers', type=int, default=2, help='number of layers') 28 | parser.add_argument('--hidden_dim', type=int, default=64, help='hidden dimension') 29 | parser.add_argument('--x_dim', type=int, default=0, help='dim of raw node features') 30 | parser.add_argument('--data_usage', type=float, default=1.0, help='use partial dataset') 31 | parser.add_argument('--train_ratio', type=float, default=0.05, help='mask partial edges for training') 32 | parser.add_argument('--valid_ratio', type=float, default=0.1, help='use partial valid set') 33 | parser.add_argument('--test_ratio', type=float, default=1.0, help='use partial test set') 34 | parser.add_argument('--metric', type=str, default='mrr', help='metric for evaluating performance', 35 | choices=['auc', 'mrr', 'hit']) 36 | 
parser.add_argument('--seed', type=int, default=0, help='seed to initialize all the random modules') 37 | parser.add_argument('--gpu_id', type=int, default=0, help='gpu id') 38 | parser.add_argument('--nthread', type=int, default=16, help='number of thread') 39 | 40 | # features and positional encoding 41 | parser.add_argument('--B_size', type=int, default=1500, help='set size of train sampling') 42 | parser.add_argument('--num_walk', type=int, default=100, help='total number of random walks') 43 | parser.add_argument('--num_step', type=int, default=4, help='total steps of random walk') 44 | parser.add_argument('--k', type=int, default=50, help='number of paired negative edges') 45 | parser.add_argument('--directed', type=bool, default=False, help='whether to treat the graph as directed') 46 | parser.add_argument('--use_feature', action='store_true', help='whether to use raw features as input') 47 | parser.add_argument('--use_weight', action='store_true', help='whether to use edge weight as input') 48 | parser.add_argument('--use_degree', action='store_true', help='whether to use node degree as input') 49 | parser.add_argument('--use_htype', action='store_true', help='whether to use node type as input') 50 | parser.add_argument('--use_val', action='store_true', help='whether to use val as input') 51 | parser.add_argument('--norm', type=str, default='all', help='method of normalization') 52 | 53 | # model training 54 | parser.add_argument('--optim', type=str, default='adam', help='optimizer to use') 55 | parser.add_argument('--rtest', type=int, default=499, help='step start to test') 56 | parser.add_argument('--eval_steps', type=int, default=100, help='number of steps to test') 57 | parser.add_argument('--batch_size', type=int, default=32, help='mini-batch size (train)') 58 | parser.add_argument('--batch_num', type=int, default=2000, help='mini-batch size (test)') 59 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 60 | 
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate') 61 | parser.add_argument('--l2', type=float, default=0., help='l2 regularization (weight decay)') 62 | parser.add_argument('--patience', type=int, default=5, help='early stopping steps') 63 | parser.add_argument('--repeat', type=int, default=1, help='number of training instances to repeat') 64 | 65 | # logging & debug 66 | parser.add_argument('--log_dir', type=str, default='./log/', help='log directory') 67 | parser.add_argument('--res_dir', type=str, default='./dataset/save', help='resource directory') 68 | parser.add_argument('--stamp', type=str, default='', help='time stamp') 69 | parser.add_argument('--summary_file', type=str, default='result_summary.log', 70 | help='brief summary of training results') 71 | parser.add_argument('--debug', default=False, action='store_true', help='whether to use debug mode') 72 | parser.add_argument('--load_dict', default=False, action='store_true', help='whether to load RPE from files') 73 | parser.add_argument('--save', default=False, action='store_true', help='whether to save RPE to files') 74 | parser.add_argument('--load_model', default=False, action='store_true', 75 | help='whether to load saved model from files') 76 | parser.add_argument('--memo', type=str, help='notes') 77 | 78 | sys_argv = sys.argv 79 | try: 80 | args = parser.parse_args() 81 | except: 82 | parser.print_help() 83 | sys.exit(0) 84 | 85 | set_random_seed(args) 86 | 87 | # customized for each dataset 88 | if 'ddi' in args.dataset: 89 | args.metric = 'Hits@20' 90 | elif 'collab' in args.dataset: 91 | args.metric = 'Hits@50' 92 | elif 'ppa' in args.dataset: 93 | args.metric = 'Hits@100' 94 | elif 'citation' in args.dataset: 95 | args.metric = 'mrr' 96 | else: 97 | raise NotImplementedError 98 | 99 | # setup logger and tensorboard 100 | logger = set_up_log(args, sys_argv) 101 | if args.nthread > 0: 102 | torch.set_num_threads(args.nthread) 103 | logger.info(f"torch num_threads 
{torch.get_num_threads()}") 104 | tb = SummaryWriter() 105 | 106 | save_dir = f'{args.res_dir}/model/{args.dataset}' 107 | if not os.path.exists(save_dir): 108 | os.mkdir(save_dir) 109 | 110 | device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu') 111 | prefix = f'{save_dir}/{args.stamp}_{args.num_step}_{args.num_walk}' 112 | 113 | g_class = DEDataset(args.dataset, args.train_ratio, 114 | use_weight=args.use_weight, 115 | use_coalesce=args.use_weight, 116 | use_degree=args.use_degree, 117 | use_val=args.use_val) 118 | evaluator = Evaluator(name=args.dataset) 119 | graphs = g_class.process(logger) 120 | 121 | # edges for negative sampling 122 | T_edge_idx, F_edge_idx = g_class.pos_edge.t().contiguous(), g_class.train_edge.t().contiguous() 123 | 124 | # define model and optim 125 | model = Net(num_layers=args.layers, input_dim=args.num_step, hidden_dim=args.hidden_dim, out_dim=1, 126 | num_walk=args.num_walk, x_dim=args.x_dim, dropout=args.dropout, use_feature=args.use_feature, 127 | use_weight=args.use_weight, use_degree=args.use_degree, use_htype=args.use_htype) 128 | model.to(device) 129 | 130 | logger.info(f'#Model Params {sum(p.numel() for p in model.parameters())}') 131 | 132 | if args.optim == 'adam': 133 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) 134 | else: 135 | raise NotImplementedError 136 | 137 | if args.load_model: 138 | load_checkpoint(model, optimizer, filename=prefix) 139 | 140 | test_dict, val_dict, inf_set = gen_dataset(g_class, graphs, args) 141 | 142 | logger.info( 143 | f'Samples: valid {inf_set["val"]["num_pos"]} by {inf_set["val"]["num_neg"]} ' 144 | f'test {inf_set["test"]["num_pos"]} by {inf_set["test"]["num_neg"]} metric: {args.metric}') 145 | 146 | G_pos, G_train = graphs['pos'], graphs['train'] 147 | num_pos, num_seed, num_cand = len(set(G_pos.indices)), 100, 5 148 | 149 | candidates = G_pos.getnnz(axis=1).argsort()[-num_seed:][::-1] 150 | 151 | rw_dict = {} 152 | B_queues = [] 153 
| 154 | for r in range(1, args.repeat + 1): 155 | res_dict = {'test_AUC': [], 'val_AUC': [], f'test_{args.metric}': [], f'val_{args.metric}': []} 156 | model.reset_parameters() 157 | logger.info(f'Running Round {r}') 158 | batchIdx, patience = 0, 0 159 | pools = np.copy(candidates) 160 | np.random.shuffle(B_queues) 161 | while True: 162 | if r <= 1: 163 | seeds = np.random.choice(pools, 5, replace=False) 164 | B_queues.append(sorted(run_sample(G_pos.indptr, G_pos.indices, seeds, thld=args.B_size))) 165 | B_pos = B_queues[batchIdx] 166 | B_w = [b for b in B_pos if b not in rw_dict] 167 | if len(B_w) > 0: 168 | walk_set, freqs = run_walk(G_train.indptr, G_train.indices, B_w, num_walks=args.num_walk, 169 | num_steps=args.num_step - 1, replacement=True) 170 | node_id, node_freq = freqs[:, 0], freqs[:, 1] 171 | rw_dict.update(dict(zip(B_w, zip(walk_set, node_id, node_freq)))) 172 | else: 173 | if batchIdx >= len(B_queues): 174 | break 175 | else: 176 | B_pos = B_queues[batchIdx] 177 | batchIdx += 1 178 | 179 | # obtain set of walks, node id and DE (counts) from the dictionary 180 | S, K, F = zip(*itemgetter(*B_pos)(rw_dict)) 181 | B_pos_edge, _ = subgraph(list(B_pos), T_edge_idx) 182 | B_full_edge, _ = subgraph(list(B_pos), F_edge_idx) 183 | data = gen_sample(np.asarray(S), B_pos, K, B_pos_edge, B_full_edge, inf_set['X'], args, gtype=g_class.gtype) 184 | F = np.concatenate(F) 185 | mF = torch.from_numpy(np.concatenate([[[0] * F.shape[-1]], F])).to(device) 186 | gT = normalization(mF, args) 187 | loss, auc = train(model, optimizer, data, gT) 188 | logger.info(f'Batch {batchIdx}\tW{len(rw_dict)}/D{num_pos}\tLoss: {loss:.4f}, AUC: {auc:.4f}') 189 | tb.add_scalar("Loss/train", loss, batchIdx) 190 | tb.add_scalar("AUC/train", auc, batchIdx) 191 | 192 | if batchIdx > args.rtest and batchIdx % args.eval_steps == 0: 193 | bvtime = time.time() 194 | out = eval_model(model, val_dict, inf_set, args, evaluator, device, mode='val') 195 | if log_record(logger, tb, out, res_dict, 
bvtime, batchIdx): 196 | patience = 0 197 | bttime = time.time() 198 | out = eval_model(model, test_dict, inf_set, args, evaluator, device, mode='test') 199 | if log_record(logger, tb, out, res_dict, bttime, batchIdx): 200 | checkpoint = {'state_dict': model.state_dict(), 201 | 'optimizer': optimizer.state_dict(), 202 | 'epoch': batchIdx} 203 | save_checkpoint(checkpoint, filename=prefix) 204 | else: 205 | patience += 1 206 | 207 | if patience > args.patience: 208 | break 209 | tb.close() 210 | -------------------------------------------------------------------------------- /main_hetro.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | 6 | import torch 7 | from surel_gacc import run_sample 8 | from ogb.linkproppred import Evaluator 9 | from torch.utils.tensorboard import SummaryWriter 10 | from torch_geometric.utils import subgraph 11 | 12 | from dataloaders import * 13 | from log import * 14 | from models.model import Net 15 | from train import * 16 | from utils import * 17 | import sys 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser('Interface for SUREL (Relation Prediction)') 21 | 22 | # general model and training setting 23 | parser.add_argument('--dataset', type=str, default='mag', help='dataset name', 24 | choices=['mag']) 25 | parser.add_argument('--relation', type=str, default='cite', help='relation type', 26 | choices=['write', 'cite']) 27 | parser.add_argument('--model', type=str, default='RNN', help='base model to use', 28 | choices=['RNN', 'MLP', 'Transformer', 'GNN']) 29 | parser.add_argument('--layers', type=int, default=2, help='number of layers') 30 | parser.add_argument('--hidden_dim', type=int, default=64, help='hidden dimension') 31 | parser.add_argument('--x_dim', type=int, default=0, help='dim of raw node features') 32 | parser.add_argument('--data_usage', type=float, default=1.0, help='use partial 
dataset') 33 | parser.add_argument('--train_ratio', type=float, default=0.05, help='mask partial edges for training') 34 | parser.add_argument('--valid_ratio', type=float, default=0.1, help='use partial valid set') 35 | parser.add_argument('--test_ratio', type=float, default=1.0, help='use partial test set') 36 | parser.add_argument('--metric', type=str, default='mrr', help='metric for evaluating performance', 37 | choices=['auc', 'mrr', 'hit']) 38 | parser.add_argument('--seed', type=int, default=0, help='seed to initialize all the random modules') 39 | parser.add_argument('--gpu_id', type=int, default=0, help='gpu id') 40 | parser.add_argument('--nthread', type=int, default=16, help='number of thread') 41 | 42 | # features and positional encoding 43 | parser.add_argument('--B_size', type=int, default=1500, help='set size of train sampling') 44 | parser.add_argument('--num_walk', type=int, default=100, help='total number of random walks') 45 | parser.add_argument('--num_step', type=int, default=4, help='total steps of random walk') 46 | parser.add_argument('--k', type=int, default=50, help='number of paired negative edges') 47 | parser.add_argument('--directed', type=bool, default=False, help='whether to treat the graph as directed') 48 | parser.add_argument('--use_feature', action='store_true', help='whether to use raw features as input') 49 | parser.add_argument('--use_weight', action='store_true', help='whether to use edge weight as input') 50 | parser.add_argument('--use_degree', action='store_true', help='whether to use node degree as input') 51 | parser.add_argument('--use_htype', action='store_true', help='whether to use node type as input') 52 | parser.add_argument('--use_val', action='store_true', help='whether to use val as input') 53 | parser.add_argument('--norm', type=str, default='all', help='method of normalization') 54 | 55 | # model training 56 | parser.add_argument('--optim', type=str, default='adam', help='optimizer to use') 57 | 
parser.add_argument('--rtest', type=int, default=499, help='step start to test') 58 | parser.add_argument('--eval_steps', type=int, default=200, help='number of steps to test') 59 | parser.add_argument('--batch_size', type=int, default=32, help='mini-batch size (train)') 60 | parser.add_argument('--batch_num', type=int, default=2000, help='mini-batch size (test)') 61 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 62 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate') 63 | parser.add_argument('--l2', type=float, default=0., help='l2 regularization (weight decay)') 64 | parser.add_argument('--patience', type=int, default=5, help='early stopping steps') 65 | parser.add_argument('--repeat', type=int, default=5, help='number of training instances to repeat') 66 | 67 | # logging & debug 68 | parser.add_argument('--log_dir', type=str, default='./log/', help='log directory') 69 | parser.add_argument('--res_dir', type=str, default='./dataset/save', help='resource directory') 70 | parser.add_argument('--stamp', type=str, default='', help='time stamp') 71 | parser.add_argument('--summary_file', type=str, default='result_summary.log', 72 | help='brief summary of training results') 73 | parser.add_argument('--debug', default=False, action='store_true', help='whether to use debug mode') 74 | parser.add_argument('--save', default=False, action='store_true', help='whether to save RPE to files') 75 | parser.add_argument('--load_model', default=False, action='store_true', 76 | help='whether to load saved model from files') 77 | parser.add_argument('--memo', type=str, help='notes') 78 | 79 | sys_argv = sys.argv 80 | try: 81 | args = parser.parse_args() 82 | except: 83 | parser.print_help() 84 | sys.exit(0) 85 | 86 | # customized for each dataset 87 | if 'mag' in args.dataset: 88 | args.metric = 'mrr' 89 | else: 90 | raise NotImplementedError 91 | 92 | # setup logger and tensorboard 93 | logger = set_up_log(args, sys_argv) 94 
| if args.nthread > 0: 95 | torch.set_num_threads(args.nthread) 96 | logger.info(f"torch num_threads {torch.get_num_threads()}") 97 | tb = SummaryWriter() 98 | 99 | device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu') 100 | prefix = f'{args.res_dir}/model/{args.dataset}/{args.stamp}_{args.num_step}_{args.num_walk}' 101 | 102 | g_class = DE_Hetro_Dataset(args.dataset, args.relation) 103 | args.x_dim = len(g_class.node_type) 104 | graphs = g_class.process(logger) 105 | 106 | # edges for negative sampling 107 | T_edge_idx, F_edge_idx = g_class.pos_edge.t().contiguous(), g_class.train_edge.t().contiguous() 108 | 109 | # define model and optim 110 | model = Net(num_layers=args.layers, input_dim=args.num_step, hidden_dim=args.hidden_dim, out_dim=1, 111 | num_walk=args.num_walk, x_dim=args.x_dim, dropout=args.dropout, use_feature=args.use_feature, 112 | use_weight=args.use_weight, use_degree=args.use_degree, use_htype=args.use_htype) 113 | model.to(device) 114 | 115 | if args.optim == 'adam': 116 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) 117 | else: 118 | raise NotImplementedError 119 | 120 | evaluator = Evaluator(name=args.dataset) if 'mag' not in args.dataset else Evaluator(name='ogbl-citation2') 121 | 122 | if args.load_model: 123 | load_checkpoint(model, optimizer, filename=prefix) 124 | 125 | test_dict, val_dict, inf_set = gen_dataset(g_class, graphs, args) 126 | 127 | logger.info( 128 | f'Samples: valid {inf_set["val"]["num_pos"]} by {inf_set["val"]["num_neg"]} ' 129 | f'test {inf_set["test"]["num_pos"]} by {inf_set["test"]["num_neg"]} metric: {args.metric}') 130 | 131 | G_pos, G_train = graphs['pos'], graphs['train'] 132 | num_pos, num_seed, num_cand = len(set(G_pos.indices)), 100, 5 133 | 134 | deg = G_pos.getnnz(axis=1) 135 | candidates = np.concatenate( 136 | [deg[:736389].argsort()[-num_seed // 2:][::-1], deg[736389:].argsort()[-num_seed // 2:][::-1] + 736389]) 137 | 138 | rw_dict = {} 139 | B_queues = 
[] 140 | 141 | for r in range(1, args.repeat + 1): 142 | res_dict = {'test_AUC': [], 'val_AUC': [], f'test_{args.metric}': [], f'val_{args.metric}': []} 143 | model.reset_parameters() 144 | logger.info(f'Running Round {r}') 145 | batchIdx, patience = 0, 0 146 | pools = np.copy(candidates) 147 | np.random.shuffle(B_queues) 148 | while True: 149 | if r <= 1: 150 | seeds = np.random.choice(pools, 5, replace=False) 151 | B_queues.append(sorted(run_sample(G_pos.indptr, G_pos.indices, seeds, thld=args.B_size))) 152 | B_pos = B_queues[batchIdx] 153 | B_w = [b for b in B_pos if b not in rw_dict] 154 | if len(B_w) > 0: 155 | walk_set, freqs = run_walk(G_train.indptr, G_train.indices, B_w, num_walks=args.num_walk, 156 | num_steps=args.num_step - 1, replacement=True) 157 | node_id, node_freq = freqs[:, 0], freqs[:, 1] 158 | rw_dict.update(dict(zip(B_w, zip(walk_set, node_id, node_freq)))) 159 | else: 160 | if batchIdx >= len(B_queues): 161 | break 162 | else: 163 | B_pos = B_queues[batchIdx] 164 | batchIdx += 1 165 | 166 | # obtain set of walks, node id and RPE (counts) from the dictionary 167 | S, K, F = zip(*itemgetter(*B_pos)(rw_dict)) 168 | B_pos_edge, _ = subgraph(list(B_pos), T_edge_idx) 169 | B_full_edge, _ = subgraph(list(B_pos), F_edge_idx) 170 | data = gen_sample(np.asarray(S), B_pos, K, B_pos_edge, B_full_edge, inf_set['X'], args, gtype=g_class.gtype) 171 | F = np.concatenate(F) 172 | mF = torch.from_numpy(np.concatenate([[[0] * F.shape[-1]], F])).to(device) 173 | gT = normalization(mF, args) 174 | loss, auc = train(model, optimizer, data, gT) 175 | logger.info(f'Batch {batchIdx}\tW{len(rw_dict)}/D{num_pos}\tLoss: {loss:.4f}, AUC: {auc:.4f}') 176 | tb.add_scalar("Loss/train", loss, batchIdx) 177 | tb.add_scalar("AUC/train", auc, batchIdx) 178 | 179 | if batchIdx > args.rtest and batchIdx % args.eval_steps == 0: 180 | bvtime = time.time() 181 | out = eval_model(model, val_dict, inf_set, args, evaluator, device, mode='val') 182 | if log_record(logger, tb, out, 
res_dict, bvtime, batchIdx): 183 | patience = 0 184 | bttime = time.time() 185 | out = eval_model(model, test_dict, inf_set, args, evaluator, device, mode='test') 186 | if log_record(logger, tb, out, res_dict, bttime, batchIdx): 187 | checkpoint = {'state_dict': model.state_dict(), 188 | 'optimizer': optimizer.state_dict(), 189 | 'epoch': batchIdx} 190 | save_checkpoint(checkpoint, filename=prefix) 191 | else: 192 | patience += 1 193 | 194 | if patience > args.patience: 195 | break 196 | tb.close() 197 | -------------------------------------------------------------------------------- /main_horder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import sys 6 | 7 | from ogb.linkproppred import Evaluator 8 | from torch.utils.data import DataLoader 9 | 10 | from dataloaders import * 11 | from log import * 12 | from models.model_horder import HONet 13 | from train import * 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser('Interface for SUREL (Higher-Order Prediction)') 17 | 18 | # general model and training setting 19 | parser.add_argument('--dataset', type=str, default='DBLP-coauthor', help='dataset name', 20 | choices=['DBLP-coauthor', 'tags-math']) 21 | parser.add_argument('--model', type=str, default='RNN', help='base model to use', 22 | choices=['RNN', 'MLP', 'Transformer', 'GNN']) 23 | parser.add_argument('--layers', type=int, default=2, help='number of layers') 24 | parser.add_argument('--hidden_dim', type=int, default=64, help='hidden dimension') 25 | parser.add_argument('--x_dim', type=int, default=0, help='dim of raw node features') 26 | parser.add_argument('--data_usage', type=float, default=1.0, help='use partial dataset') 27 | parser.add_argument('--train_ratio', type=float, default=0.6, help='mask partial edges for training') 28 | parser.add_argument('--valid_ratio', type=float, default=0.1, help='use partial 
valid set') 29 | parser.add_argument('--test_ratio', type=float, default=1.0, help='use partial test set') 30 | parser.add_argument('--metric', type=str, default='mrr', help='metric for evaluating performance', 31 | choices=['auc', 'mrr', 'hit']) 32 | parser.add_argument('--seed', type=int, default=0, help='seed to initialize all the random modules') 33 | parser.add_argument('--gpu_id', type=int, default=1, help='gpu id') 34 | parser.add_argument('--nthread', type=int, default=16, help='number of thread') 35 | 36 | # features and positional encoding 37 | parser.add_argument('--B_size', type=int, default=1500, help='set size of train sampling') 38 | parser.add_argument('--num_walk', type=int, default=100, help='total number of random walks') 39 | parser.add_argument('--num_step', type=int, default=3, help='total steps of random walk') 40 | parser.add_argument('--k', type=int, default=10, help='number of paired negative edges') 41 | parser.add_argument('--directed', type=bool, default=False, help='whether to treat the graph as directed') 42 | parser.add_argument('--use_feature', action='store_true', help='whether to use raw features as input') 43 | parser.add_argument('--use_weight', action='store_true', help='whether to use edge weight as input') 44 | parser.add_argument('--use_degree', action='store_true', help='whether to use node degree as input') 45 | parser.add_argument('--use_val', action='store_true', help='whether to use val as input') 46 | parser.add_argument('--norm', type=str, default='all', help='method of normalization') 47 | 48 | # model training 49 | parser.add_argument('--optim', type=str, default='adam', help='optimizer to use') 50 | parser.add_argument('--rtest', type=int, default=499, help='step start to test') 51 | parser.add_argument('--eval_steps', type=int, default=100, help='number of steps to test') 52 | parser.add_argument('--batch_size', type=int, default=32, help='mini-batch size (train)') 53 | parser.add_argument('--batch_num', type=int, 
default=2000, help='mini-batch size (test)') 54 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 55 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate') 56 | parser.add_argument('--l2', type=float, default=0., help='l2 regularization (weight decay)') 57 | parser.add_argument('--patience', type=int, default=3, help='early stopping steps') 58 | parser.add_argument('--repeat', type=int, default=5, help='number of training instances to repeat') 59 | 60 | # logging & debug 61 | parser.add_argument('--log_dir', type=str, default='./log/', help='log directory') 62 | parser.add_argument('--res_dir', type=str, default='./dataset/save', help='resource directory') 63 | parser.add_argument('--stamp', type=str, default='', help='time stamp') 64 | parser.add_argument('--summary_file', type=str, default='result_summary.log', 65 | help='brief summary of training results') 66 | parser.add_argument('--debug', default=False, action='store_true', help='whether to use debug mode') 67 | parser.add_argument('--save', default=False, action='store_true', help='whether to save RPE to files') 68 | parser.add_argument('--load_model', default=False, action='store_true', 69 | help='whether to load saved model from files') 70 | parser.add_argument('--memo', type=str, help='notes') 71 | 72 | sys_argv = sys.argv 73 | try: 74 | args = parser.parse_args() 75 | except: 76 | parser.print_help() 77 | sys.exit(0) 78 | 79 | # setup logger and tensorboard 80 | logger = set_up_log(args, sys_argv) 81 | if args.nthread > 0: 82 | torch.set_num_threads(args.nthread) 83 | logger.info(f"torch num_threads {torch.get_num_threads()}") 84 | 85 | device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu') 86 | prefix = f'{args.res_dir}/model/{args.dataset}/{args.stamp}_{args.num_step}_{args.num_walk}' 87 | g_class = DE_Hyper_Dataset(args.dataset) 88 | G_enc = g_class.process(logger) 89 | 90 | # define model and optim 91 | model = 
HONet(num_layers=args.layers, input_dim=args.num_step, hidden_dim=args.hidden_dim, out_dim=1, 92 | num_walk=args.num_walk, x_dim=args.x_dim, dropout=args.dropout) 93 | model.to(device) 94 | 95 | if args.optim == 'adam': 96 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) 97 | else: 98 | raise NotImplementedError 99 | 100 | evaluator = Evaluator(name='ogbl-citation2') 101 | 102 | if args.load_model: 103 | load_checkpoint(model, optimizer, filename=prefix) 104 | 105 | inf_dict, inf_set = gen_dataset_hyper(g_class, G_enc, args) 106 | 107 | logger.info( 108 | f'Samples: valid {inf_set["val"]["num_pos"]} by {inf_set["val"]["num_neg"]} ' 109 | f'test {inf_set["test"]["num_pos"]} by {inf_set["test"]["num_neg"]} metric: {args.metric}') 110 | 111 | rw_dict = {} 112 | triplets = g_class.split_edge['train']['hedge'] 113 | num_pos = len(set(G_enc.indices)) 114 | loader = DataLoader(range(len(triplets)), args.batch_size, shuffle=True) 115 | 116 | num_batch = 0 117 | for r in range(1, args.repeat + 1): 118 | res_dict = {'test_AUC': [], 'val_AUC': [], f'test_{args.metric}': [], f'val_{args.metric}': []} 119 | model.reset_parameters() 120 | logger.info(f'Running Round {r}') 121 | batchIdx, patience = 0, 0 122 | for perm in loader: 123 | batchIdx += 1 124 | batch = triplets[perm] 125 | B_pos = np.unique(batch) 126 | B_w = [b for b in B_pos if b not in rw_dict] 127 | if len(B_w) > 0: 128 | walk_set, freqs = run_walk(G_enc.indptr, G_enc.indices, B_w, num_walks=args.num_walk, 129 | num_steps=args.num_step - 1, replacement=True) 130 | node_id, node_freq = freqs[:, 0], freqs[:, 1] 131 | rw_dict.update(dict(zip(B_w, zip(walk_set, node_id, node_freq)))) 132 | 133 | # obtain set of walks, node id and DE (counts) from the dictionary 134 | W, S, F = zip(*itemgetter(*B_pos)(rw_dict)) 135 | data = gen_tuple(np.asarray(W), B_pos, S, batch, args) 136 | F = np.concatenate(F) 137 | mF = torch.from_numpy(np.concatenate([[[0] * F.shape[-1]], F])).to(device) 138 | gT = 
normalization(mF, args) 139 | loss, auc = train(model, optimizer, data, gT) 140 | logger.info(f'Batch {batchIdx}\tW{len(rw_dict)}/D{num_pos}\tLoss: {loss:.4f}, AUC: {auc:.4f}') 141 | 142 | if batchIdx > args.rtest and batchIdx % args.eval_steps == 0: 143 | bvtime = time.time() 144 | out = eval_model_horder(model, inf_dict, inf_set, args, evaluator, device, mode='val') 145 | if log_record(logger, None, out, res_dict, bvtime, batchIdx): 146 | patience = 0 147 | bttime = time.time() 148 | out = eval_model_horder(model, inf_dict, inf_set, args, evaluator, device, mode='test') 149 | if log_record(logger, None, out, res_dict, bttime, batchIdx): 150 | checkpoint = {'state_dict': model.state_dict(), 151 | 'optimizer': optimizer.state_dict(), 152 | 'epoch': batchIdx} 153 | save_checkpoint(checkpoint, filename=prefix) 154 | else: 155 | patience += 1 156 | if patience > args.patience: 157 | break 158 | -------------------------------------------------------------------------------- /models/layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | # MLP with linear output 10 | class MLP(nn.Module): 11 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim): 12 | """ 13 | num_layers: number of layers in the neural networks (EXCLUDING the input layer). 14 | If num_layers=1, this reduces to linear model. 
15 | input_dim: dimensionality of input features 16 | hidden_dim: dimensionality of hidden units at ALL layers 17 | output_dim: number of classes for prediction 18 | device: which device to use 19 | """ 20 | 21 | super(MLP, self).__init__() 22 | 23 | self.linear_or_not = True # default is linear model 24 | self.num_layers = num_layers 25 | 26 | if num_layers < 1: 27 | raise ValueError("number of layers should be positive!") 28 | elif num_layers == 1: 29 | # Linear model 30 | self.linear = nn.Linear(input_dim, output_dim) 31 | else: 32 | # Multi-layer model 33 | self.linear_or_not = False 34 | self.linears = torch.nn.ModuleList() 35 | self.batch_norms = torch.nn.ModuleList() 36 | 37 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 38 | for layer in range(num_layers - 2): 39 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 40 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 41 | 42 | for layer in range(num_layers - 1): 43 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) 44 | 45 | def forward(self, x): 46 | if self.linear_or_not: 47 | # If linear model 48 | return self.linear(x) 49 | else: 50 | # If MLP 51 | h = x 52 | for layer in range(self.num_layers - 1): 53 | h = F.relu(self.batch_norms[layer](self.linears[layer](h))) 54 | return self.linears[self.num_layers - 1](h) 55 | 56 | def reset_parameters(self): 57 | # reset parameters for retraining 58 | if self.num_layers == 1: 59 | self.linear.reset_parameters() 60 | else: 61 | # rest linear layers 62 | for linear in self.linears: 63 | linear.reset_parameters() 64 | # rest normalization layers 65 | for norm in self.batch_norms: 66 | norm.reset_parameters() 67 | 68 | 69 | class RNN(nn.Module): 70 | def __init__(self, num_layers, input_dim, hidden_dim, out_dim, dropout=0.0, mtype='lstm'): 71 | super(RNN, self).__init__() 72 | self.hidden_dim = hidden_dim 73 | self.num_layers = num_layers 74 | self.rnn = nn.LSTM(input_size=input_dim, 75 | hidden_size=hidden_dim, 76 | num_layers=num_layers, 77 | 
batch_first=True) 78 | self.rnn_type = mtype 79 | self.dropout = dropout 80 | 81 | def forward(self, x, walks): 82 | out, _ = self.rnn(x) 83 | enc = out.select(dim=1, index=-1).view(-1, walks, self.hidden_dim) 84 | enc = F.dropout(enc, p=self.dropout, training=self.training) 85 | enc_agg = torch.mean(enc, dim=1) 86 | return enc_agg 87 | 88 | def init_hidden(self, batch_size): 89 | if self.rnn_type == 'gru': 90 | return torch.zeros(self.num_layers * self.directions_count, batch_size, self.hidden_dim).to(self.device) 91 | elif self.rnn_type == 'lstm': 92 | return ( 93 | torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device), 94 | torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)) 95 | elif self.rnn_type == 'rnn': 96 | return torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device) 97 | else: 98 | raise Exception('Unknown rnn_type. Valid options: "gru", "lstm", or "rnn"') 99 | 100 | def reset_parameters(self): 101 | # reset parameters for retraining 102 | self.rnn.reset_parameters() 103 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | from models.layer import MLP, RNN 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | 9 | 10 | class MergeLayer(torch.nn.Module): 11 | def __init__(self, dim1, dim2, dim3, dim4, non_linear=True, dropout=0.5): 12 | super().__init__() 13 | self.fc1 = nn.Linear(dim1 + dim2, dim3) 14 | self.fc2 = nn.Linear(dim3, dim4) 15 | self.act = nn.ReLU() 16 | self.dropout = dropout 17 | 18 | nn.init.xavier_normal_(self.fc1.weight) 19 | nn.init.xavier_normal_(self.fc2.weight) 20 | 21 | self.non_linear = non_linear 22 | if not non_linear: 23 | assert (dim1 == dim2) 24 | self.fc = nn.Linear(dim1, 1) 25 | nn.init.xavier_normal_(self.fc1.weight) 26 | 27 | def forward(self, 
x1, x2): 28 | z_walk = None 29 | if self.non_linear: 30 | x = torch.cat([x1, x2], dim=-1) 31 | h = self.act(self.fc1(x)) 32 | h = F.dropout(h, p=self.dropout, training=self.training) 33 | z = self.fc2(h) 34 | else: 35 | # x1, x2 shape: [B, M, F] 36 | x = torch.cat([x1, x2], dim=-2) # x shape: [B, 2M, F] 37 | z_walk = self.fc(x).squeeze(-1) # z_walk shape: [B, 2M] 38 | z = z_walk.sum(dim=-1, keepdim=True) # z shape [B, 1] 39 | return z, z_walk 40 | 41 | def reset_parameter(self): 42 | self.fc1.reset_parameters() 43 | self.fc2.reset_parameters() 44 | nn.init.xavier_normal_(self.fc1.weight) 45 | nn.init.xavier_normal_(self.fc2.weight) 46 | 47 | 48 | class Net(torch.nn.Module): 49 | def __init__(self, num_layers, input_dim, hidden_dim, out_dim, num_walk, x_dim=0, dropout=0.5, 50 | use_feature=False, use_weight=False, use_degree=False, use_htype=False): 51 | super(Net, self).__init__() 52 | self.use_feature = use_feature 53 | self.use_degree = use_degree 54 | self.use_htype = use_htype 55 | self.dropout = dropout 56 | self.x_dim = x_dim 57 | self.enc = 'LP' # landing prob at [0, 1, ... 
num_layers] 58 | 59 | add_dim = 1 if use_weight else 0 60 | self.trainable_embedding = nn.Sequential(nn.Linear(in_features=input_dim + add_dim, out_features=hidden_dim), 61 | nn.ReLU(), nn.Linear(in_features=hidden_dim, out_features=hidden_dim)) 62 | print("Relative Positional Encoding: {}".format(self.enc)) 63 | if use_feature: 64 | self.rnn = RNN(num_layers, hidden_dim * 2, hidden_dim, out_dim) 65 | elif use_htype: 66 | self.rnn = RNN(num_layers, hidden_dim + x_dim, hidden_dim, out_dim) 67 | else: 68 | self.rnn = RNN(num_layers, hidden_dim, hidden_dim, out_dim) 69 | if use_htype: 70 | self.ntype_embedding = nn.Sequential(nn.Linear(in_features=x_dim, out_features=hidden_dim), nn.ReLU(), 71 | nn.Linear(in_features=hidden_dim, out_features=hidden_dim)) 72 | self.affinity_score = MergeLayer(hidden_dim, hidden_dim, hidden_dim, 1, non_linear=True, dropout=dropout) 73 | self.concat_norm = nn.LayerNorm(hidden_dim * 2) 74 | self.len_step = input_dim 75 | self.walks = num_walk 76 | 77 | def forward(self, x, feature=None, debugs=None): 78 | # out shape [2 (u,v), batch*num_walk, 2 (l,r), pos_dim] 79 | x = self.trainable_embedding(x).sum(dim=-2) 80 | 81 | if self.use_degree: 82 | deg = torch.cat(feature[-1]).to(x.device) 83 | x = x / deg 84 | elif self.use_feature: 85 | x = torch.cat([x, feature[0].to(x.device)], dim=-1) 86 | elif self.use_htype: 87 | ntype = F.one_hot(feature[0], self.x_dim).to(x.device).float() 88 | x = torch.cat([ntype, x], dim=-1) 89 | x = x.view(2, -1, self.len_step, x.shape[-1]) 90 | out_i, out_j = self.rnn(x[0], self.walks), self.rnn(x[1], self.walks) 91 | score, _ = self.affinity_score(out_i, out_j) 92 | return score.squeeze(1) 93 | 94 | def reset_parameters(self): 95 | for layer in self.trainable_embedding: 96 | if hasattr(layer, 'reset_parameters'): 97 | layer.reset_parameters() 98 | nn.init.xavier_normal_(layer.weight) 99 | self.rnn.reset_parameters() 100 | self.affinity_score.reset_parameter() 101 | 102 | 103 | class 
LinkPredictor(torch.nn.Module): 104 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 105 | dropout): 106 | super(LinkPredictor, self).__init__() 107 | 108 | self.lins = torch.nn.ModuleList() 109 | self.lins.append(torch.nn.Linear(in_channels, hidden_channels)) 110 | for _ in range(num_layers - 2): 111 | self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels)) 112 | self.lins.append(torch.nn.Linear(hidden_channels, out_channels)) 113 | self.dropout = dropout 114 | 115 | def reset_parameters(self): 116 | for lin in self.lins: 117 | lin.reset_parameters() 118 | 119 | def forward(self, x_i, x_j): 120 | x = x_i * x_j 121 | for lin in self.lins[:-1]: 122 | x = lin(x) 123 | x = F.relu(x) 124 | x = F.dropout(x, p=self.dropout, training=self.training) 125 | x = self.lins[-1](x) 126 | return x 127 | -------------------------------------------------------------------------------- /models/model_horder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | from models.layer import MLP, RNN 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | 9 | 10 | class MergeLayer(torch.nn.Module): 11 | def __init__(self, dim1, dim2, dim3, non_linear=True, dropout=0.5): 12 | super().__init__() 13 | self.fc1 = nn.Linear(dim1 * 4, dim2) 14 | self.fc2 = nn.Linear(dim2, dim3) 15 | self.act = nn.ReLU() 16 | self.dropout = dropout 17 | 18 | nn.init.xavier_normal_(self.fc1.weight) 19 | nn.init.xavier_normal_(self.fc2.weight) 20 | 21 | self.non_linear = non_linear 22 | if not non_linear: 23 | assert (dim1 == dim2 == dim3) 24 | self.fc = nn.Linear(dim1, 1) 25 | nn.init.xavier_normal_(self.fc1.weight) 26 | 27 | def forward(self, x1, x2, x3, x4): 28 | z_walk = None 29 | if self.non_linear: 30 | x = torch.cat([x1, x2, x3, x4], dim=-1) 31 | h = self.act(self.fc1(x)) 32 | h = F.dropout(h, p=self.dropout, training=self.training) 33 | z = 
self.fc2(h) 34 | else: 35 | x = torch.cat([x1, x2, x3, x4], dim=-2) # x shape: [B, 2M, F] 36 | z_walk = self.fc(x).squeeze(-1) # z_walk shape: [B, 2M] 37 | z = z_walk.sum(dim=-1, keepdim=True) # z shape [B, 1] 38 | return z, z_walk 39 | 40 | def reset_parameter(self): 41 | self.fc1.reset_parameters() 42 | self.fc2.reset_parameters() 43 | 44 | 45 | class HONet(torch.nn.Module): 46 | def __init__(self, num_layers, input_dim, hidden_dim, out_dim, num_walk, x_dim=0, dropout=0.5): 47 | super(HONet, self).__init__() 48 | self.dropout = dropout 49 | self.x_dim = x_dim 50 | self.enc = 'LP' # landing prob at [0, 1, ... num_layers] 51 | 52 | self.trainable_embedding = nn.Sequential(nn.Linear(in_features=input_dim, out_features=hidden_dim), 53 | nn.ReLU(), nn.Linear(in_features=hidden_dim, out_features=hidden_dim)) 54 | 55 | print("Relative Positional Encoding: {}".format(self.enc)) 56 | self.rnn = RNN(num_layers, hidden_dim, hidden_dim, out_dim) 57 | self.affinity_score = MergeLayer(hidden_dim, hidden_dim, 1, non_linear=True, dropout=dropout) 58 | self.concat_norm = nn.LayerNorm(hidden_dim * 2) 59 | self.len_step = input_dim 60 | self.walks = num_walk 61 | 62 | def forward(self, x): 63 | x = self.trainable_embedding(x).sum(dim=-2) 64 | x = x.view(4, -1, self.len_step, x.shape[-1]) 65 | wu, wv, uw, vw = self.rnn(x[0], self.walks), self.rnn(x[1], self.walks), self.rnn(x[2], self.walks), self.rnn( 66 | x[3], self.walks) 67 | score, _ = self.affinity_score(wu, wv, uw, vw) 68 | return score.squeeze(1) 69 | 70 | def reset_parameters(self): 71 | for layer in self.trainable_embedding: 72 | if hasattr(layer, 'reset_parameters'): 73 | layer.reset_parameters() 74 | nn.init.xavier_normal_(layer.weight) 75 | self.rnn.reset_parameters() 76 | self.affinity_score.reset_parameter() 77 | -------------------------------------------------------------------------------- /subg_acc/.gitignore: -------------------------------------------------------------------------------- 1 | # Source: 
https://github.com/github/gitignore/blob/main/C.gitignore 2 | # Prerequisites 3 | *.d 4 | 5 | # Object files 6 | *.o 7 | *.ko 8 | *.obj 9 | *.elf 10 | 11 | # Linker output 12 | *.ilk 13 | *.map 14 | *.exp 15 | 16 | # Precompiled Headers 17 | *.gch 18 | *.pch 19 | 20 | # Libraries 21 | *.lib 22 | *.a 23 | *.la 24 | *.lo 25 | 26 | # Shared objects (inc. Windows DLLs) 27 | *.dll 28 | *.so 29 | *.so.* 30 | *.dylib 31 | 32 | # Executables 33 | *.exe 34 | *.out 35 | *.app 36 | *.i*86 37 | *.x86_64 38 | *.hex 39 | 40 | # Debug files 41 | *.dSYM/ 42 | *.su 43 | *.idb 44 | *.pdb 45 | 46 | # Kernel Module Compile Results 47 | *.mod* 48 | *.cmd 49 | .tmp_versions/ 50 | modules.order 51 | Module.symvers 52 | Mkfile.old 53 | dkms.conf 54 | -------------------------------------------------------------------------------- /subg_acc/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021-Present, Haoteng Yin 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /subg_acc/README.md: -------------------------------------------------------------------------------- 1 | # Subgraph Operation Accelerator 2 |

3 | 4 | Version 5 | 6 |

7 | 8 | The `subg_acc` package is an extension library based on C and openmp to accelerate subgraph operations in subgraph-based graph representation learning (SGRL) with multithreading enabled. Follow the principles of algorithm system co-design in [SUREL](https://arxiv.org/abs/2202.13538)/[SUREL+](https://github.com/VeritasYin/SUREL_Plus/blob/main/manuscript/SUREL_Plus_Full.pdf), query-level subgraphs (of link/motif) (e.g. ego-network in canonical SGRLs) are decomposed into reusable node-level ones. Currently, `subg_acc` consists of the following methods for the realization of scalable SGRLs: 9 | 10 | - `run_walk` walk-based subgraph sampling 11 | - `run_sample` walk-based sampling of training batches 12 | - `rpe_encoder` relative positional encoding (localized structural feature construction) 13 | - `sjoin` online subgraph joining that rebuilds the query-level subgraph from node-level ones to serve queries (a set of nodes) 14 | 15 | ## Requirements 16 | (Other versions may work, but are untested) 17 | 18 | - python >= 3.8 19 | - gcc >= 8.4 20 | - cmake >= 3.16 21 | - make >= 4.2 22 | 23 | ## Installation 24 | ``` 25 | python setup.py install 26 | ``` 27 | 28 | -------------------------------------------------------------------------------- /subg_acc/graph_acc.c: -------------------------------------------------------------------------------- 1 | #define PY_SSIZE_T_CLEAN 2 | 3 | #include 4 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 5 | #include 6 | #include 7 | #include "uthash.h" 8 | 9 | #define DEBUG 0 10 | 11 | typedef struct item_int { 12 | int key; 13 | int val; 14 | UT_hash_handle hh; 15 | } dict_int; 16 | 17 | static int find_key_int(dict_int *maps, int key) { 18 | dict_int *s; 19 | HASH_FIND_INT(maps, &key, s); /* s: output pointer */ 20 | return s ? 
s->val : -1; 21 | } 22 | 23 | void add_item(dict_int **maps, int key) { 24 | dict_int *s; 25 | HASH_FIND_INT(*maps, &key, s); /* s: output pointer */ 26 | if (s == NULL) { 27 | dict_int *k = malloc(sizeof(*k)); 28 | // walk starts from each node (main key) 29 | k->key = key; 30 | HASH_ADD_INT(*maps, key, k); 31 | } 32 | } 33 | 34 | /* hash of hashes */ 35 | typedef struct item { 36 | int key; 37 | int val; 38 | struct item *sub; 39 | UT_hash_handle hh; 40 | } dict_item; 41 | 42 | static int find_key_item(dict_item *items, int key) { 43 | dict_item *s; 44 | HASH_FIND_INT(items, &key, s); /* s: output pointer */ 45 | return s ? s->val : -1; 46 | } 47 | 48 | static int find_idx(dict_item *items, int id1, int id2) { 49 | dict_item *s, *p; 50 | HASH_FIND_INT(items, &id1, s); /* s: output pointer */ 51 | if (s != NULL) { 52 | HASH_FIND_INT(s->sub, &id2, p); 53 | return p ? p->val : 0; 54 | } else { 55 | return -1; 56 | } 57 | } 58 | 59 | void delete_all(dict_item *maps) { 60 | dict_item *item1, *item2, *tmp1, *tmp2; 61 | 62 | /* clean up both hash tables */ 63 | HASH_ITER(hh, maps, item1, tmp1) { 64 | HASH_ITER(hh, item1->sub, item2, tmp2) { 65 | HASH_DEL(item1->sub, item2); 66 | free(item2); 67 | } 68 | HASH_DEL(maps, item1); 69 | free(item1); 70 | } 71 | } 72 | 73 | // test func 74 | static PyObject *adds(PyObject *self, PyObject *args) { 75 | int arg1, arg2; 76 | if (!(PyArg_ParseTuple(args, "ii", &arg1, &arg2))) { 77 | return NULL; 78 | } 79 | return Py_BuildValue("i", arg1 * 2 + arg2 * 7); 80 | } 81 | 82 | static PyObject *exe(PyObject *self, PyObject *args) { 83 | const char *command; 84 | int sts; 85 | 86 | if (!PyArg_ParseTuple(args, "s", &command)) 87 | return NULL; 88 | sts = system(command); 89 | return PyLong_FromLong(sts); 90 | } 91 | 92 | static void f_format(const npy_intp *dims, int *CArrays) { 93 | for (int x = 0; x < dims[0]; x++) { 94 | printf("idx %d: \n", x); 95 | for (int y = 0; y < dims[1]; y++) { 96 | printf("%d ", CArrays[x * dims[1] + y]); 97 | 
} 98 | printf("\n"); 99 | } 100 | } 101 | 102 | static void 103 | random_walk(int const *ptr, int const *neighs, int const *seq, int n, int num_walks, int num_steps, int seed, 104 | int nthread, int *walks) { 105 | /* https://github.com/lkskstlr/rwalk */ 106 | if (DEBUG) { 107 | printf("get in with n: %d, num_walks: %d, num_steps: %d, seed: %d, nthread: %d\n", n, num_walks, num_steps, 108 | seed, nthread); 109 | } 110 | if (nthread > 0) { 111 | omp_set_num_threads(nthread); 112 | } 113 | #pragma omp parallel 114 | { 115 | int thread_num = omp_get_thread_num(); 116 | unsigned int private_seed = (unsigned int) (seed + thread_num); 117 | #pragma omp for 118 | for (int i = 0; i < n; i++) { 119 | int offset, num_neighs; 120 | for (int walk = 0; walk < num_walks; walk++) { 121 | int curr = seq[i]; 122 | offset = i * num_walks * (num_steps + 1) + walk * (num_steps + 1); 123 | walks[offset] = curr; 124 | for (int step = 0; step < num_steps; step++) { 125 | num_neighs = ptr[curr + 1] - ptr[curr]; 126 | if (num_neighs > 0) { 127 | curr = neighs[ptr[curr] + (rand_r(&private_seed) % num_neighs)]; 128 | } 129 | walks[offset + step + 1] = curr; 130 | } 131 | } 132 | } 133 | } 134 | } 135 | 136 | // random walk without replacement (1st neigh) 137 | static void 138 | random_walk_wo(int const *ptr, int const *neighs, int const *seq, int n, int num_walks, int num_steps, int seed, 139 | int nthread, int *walks) { 140 | if (nthread > 0) { 141 | omp_set_num_threads(nthread); 142 | } 143 | #pragma omp parallel 144 | { 145 | int thread_num = omp_get_thread_num(); 146 | unsigned int private_seed = (unsigned int) (seed + thread_num); 147 | 148 | #pragma omp for 149 | for (int i = 0; i < n; i++) { 150 | int offset, num_neighs; 151 | 152 | int num_hop1 = ptr[seq[i] + 1] - ptr[seq[i]]; 153 | int rseq[num_hop1]; 154 | if (num_hop1 > num_walks) { 155 | // https://www.programmersought.com/article/71554044511/ 156 | int s, t; 157 | for (int j = 0; j < num_hop1; j++) 158 | rseq[j] = j; 159 | for 
(int k = 0; k < num_walks; k++) { 160 | s = rand_r(&private_seed) % (num_hop1 - k) + k; 161 | t = rseq[k]; 162 | rseq[k] = rseq[s]; 163 | rseq[s] = t; 164 | } 165 | } 166 | 167 | for (int walk = 0; walk < num_walks; walk++) { 168 | int curr = seq[i]; 169 | offset = i * num_walks * (num_steps + 1) + walk * (num_steps + 1); 170 | walks[offset] = curr; 171 | if (num_hop1 < 1){ 172 | walks[offset + 1] = curr; 173 | } 174 | else if (num_hop1 <= num_walks) { 175 | curr = neighs[ptr[curr] + walk % num_hop1]; 176 | walks[offset + 1] = curr; 177 | } else { 178 | curr = neighs[ptr[curr] + rseq[walk]]; 179 | walks[offset + 1] = curr; 180 | } 181 | for (int step = 1; step < num_steps; step++) { 182 | num_neighs = ptr[curr + 1] - ptr[curr]; 183 | if (num_neighs > 0) { 184 | curr = neighs[ptr[curr] + (rand_r(&private_seed) % num_neighs)]; 185 | } 186 | walks[offset + step + 1] = curr; 187 | } 188 | } 189 | } 190 | } 191 | } 192 | 193 | 194 | void rpe_encoder(int const *arr, int idx, int num_walks, int num_steps, PyArrayObject **out) { 195 | PyArrayObject *oarr1 = NULL, *oarr2 = NULL; 196 | dict_int *mapping = NULL; 197 | int offset = idx * num_walks * (num_steps + 1); 198 | 199 | // setup root node 200 | dict_int *root = malloc(sizeof(*root)); 201 | root->key = arr[offset]; 202 | root->val = 0; 203 | HASH_ADD_INT(mapping, key, root); 204 | 205 | // setup the rest unique node 206 | int count = 1; 207 | for (int i = 1; i < num_steps + 1; i++) { 208 | for (int j = 0; j < num_walks; j++) { 209 | int token = arr[offset + j * (num_steps + 1) + i]; 210 | if (find_key_int(mapping, token) < 0) { 211 | dict_int *k = malloc(sizeof(*k)); 212 | // walk starts from each node (main key) 213 | k->key = token; 214 | k->val = count; 215 | HASH_ADD_INT(mapping, key, k); 216 | count++; 217 | } 218 | } 219 | } 220 | int num_keys = HASH_COUNT(mapping); 221 | 222 | // create a new array 223 | npy_intp odims1[2] = {num_keys, num_steps + 1}; 224 | oarr1 = (PyArrayObject *) PyArray_ZEROS(2, odims1, 
NPY_INT, 0); 225 | if (!oarr1) PyErr_SetString(PyExc_TypeError, "output error."); 226 | int *Coarr1 = (int *) PyArray_DATA(oarr1); 227 | 228 | npy_intp odims2[1] = {num_keys}; 229 | oarr2 = (PyArrayObject *) PyArray_SimpleNew(1, odims2, NPY_INT); 230 | if (!oarr2) PyErr_SetString(PyExc_TypeError, "output error."); 231 | int *Coarr2 = (int *) PyArray_DATA(oarr2); 232 | 233 | Coarr1[0] = num_walks; 234 | 235 | for (int i = 1; i < num_steps + 1; i++) { 236 | for (int j = 0; j < num_walks; j++) { 237 | int anchor = find_key_int(mapping, arr[offset + j * (num_steps + 1) + i]); 238 | Coarr1[anchor * (num_steps + 1) + i]++; 239 | } 240 | } 241 | 242 | // free mem 243 | dict_int *cur_item, *tmp; 244 | HASH_ITER(hh, mapping, cur_item, tmp) { 245 | Coarr2[cur_item->val] = cur_item->key; 246 | HASH_DEL(mapping, cur_item); /* delete it (users advances to next) */ 247 | free(cur_item); /* free it */ 248 | } 249 | out[2 * idx] = oarr2, out[2 * idx + 1] = oarr1; 250 | } 251 | 252 | static PyObject *np_sample(PyObject *self, PyObject *args, PyObject *kws) { 253 | PyObject *arg1 = NULL, *arg2 = NULL, *query = NULL; 254 | PyArrayObject *ptr = NULL, *neighs = NULL, *seq = NULL, *oarr = NULL; 255 | int num_walks = 200, num_steps = 8, seed = 111413, nthread = -1, thld = 1000; 256 | 257 | static char *kwlist[] = {"ptr", "neighs", "query", "num_walks", "num_steps", "thld", "nthread", "seed", NULL}; 258 | if (!(PyArg_ParseTupleAndKeywords(args, kws, "OOO|iiiiip", kwlist, &arg1, &arg2, &query, &num_walks, &num_steps, 259 | &thld, &nthread, &seed))) { 260 | PyErr_SetString(PyExc_TypeError, "input parsing error."); 261 | return NULL; 262 | } 263 | 264 | /* handle walks (numpy array) */ 265 | ptr = (PyArrayObject *) PyArray_FROM_OTF(arg1, NPY_INT, NPY_ARRAY_IN_ARRAY); 266 | if (!ptr) return NULL; 267 | int *Cptr = PyArray_DATA(ptr); 268 | 269 | neighs = (PyArrayObject *) PyArray_FROM_OTF(arg2, NPY_INT, NPY_ARRAY_IN_ARRAY); 270 | if (!neighs) return NULL; 271 | int *Cneighs = 
PyArray_DATA(neighs); 272 | 273 | seq = (PyArrayObject *) PyArray_FROM_OTF(query, NPY_INT, NPY_ARRAY_IN_ARRAY | NPY_ARRAY_FORCECAST); 274 | if (!seq) return NULL; 275 | int *Cseq = PyArray_DATA(seq); 276 | 277 | int n = (int) PyArray_SIZE(seq); 278 | 279 | unsigned int private_seed = (unsigned int) (seed + getpid()); 280 | /* initialize the hashtable */ 281 | dict_int *sets = NULL; 282 | for (int i = 0; i < n; i++) { 283 | int num_hop1 = Cptr[Cseq[i] + 1] - Cptr[Cseq[i]]; 284 | int num_neighs, rseq[num_hop1]; 285 | if (num_hop1 > num_walks) { 286 | int s, t; 287 | for (int j = 0; j < num_hop1; j++) 288 | rseq[j] = j; 289 | for (int k = 0; k < num_walks; k++) { 290 | s = rand_r(&private_seed) % (num_hop1 - k) + k; 291 | t = rseq[k]; 292 | rseq[k] = rseq[s]; 293 | rseq[s] = t; 294 | } 295 | } 296 | add_item(&sets, Cseq[i]); 297 | 298 | for (int walk = 0; walk < num_walks; walk++) { 299 | int curr = Cseq[i]; 300 | if (num_hop1 < 1) { 301 | break; 302 | } else if (num_hop1 <= num_walks) { 303 | curr = Cneighs[Cptr[curr] + walk % num_hop1]; 304 | } else { 305 | curr = Cneighs[Cptr[curr] + rseq[walk]]; 306 | } 307 | for (int step = 1; step < num_steps; step++) { 308 | add_item(&sets, curr); 309 | num_neighs = Cptr[curr + 1] - Cptr[curr]; 310 | if (num_neighs > 0) { 311 | curr = Cneighs[Cptr[curr] + (rand_r(&private_seed) % num_neighs)]; 312 | add_item(&sets, curr); 313 | } 314 | } 315 | if ((int) HASH_COUNT(sets) >= ((i + 1) * thld / n)) 316 | break; 317 | } 318 | } 319 | 320 | npy_intp odims[1] = {HASH_COUNT(sets)}; 321 | oarr = (PyArrayObject *) PyArray_SimpleNew(1, odims, NPY_INT); 322 | if (oarr == NULL) goto fail; 323 | int *Coarr = (int *) PyArray_DATA(oarr); 324 | 325 | // free memory 326 | dict_int *cur_item, *tmp; 327 | int idx = 0; 328 | HASH_ITER(hh, sets, cur_item, tmp) { 329 | Coarr[idx] = cur_item->key; 330 | HASH_DEL(sets, cur_item); /* delete it (users advances to next) */ 331 | free(cur_item); /* free it */ 332 | idx++; 333 | } 334 | 335 | 
Py_DECREF(ptr); 336 | Py_DECREF(neighs); 337 | Py_DECREF(seq); 338 | return PyArray_Return(oarr); 339 | 340 | fail: 341 | Py_XDECREF(ptr); 342 | Py_XDECREF(neighs); 343 | Py_XDECREF(seq); 344 | PyArray_DiscardWritebackIfCopy(oarr); 345 | PyArray_XDECREF(oarr); 346 | return NULL; 347 | } 348 | 349 | static PyObject *np_walk(PyObject *self, PyObject *args, PyObject *kws) { 350 | PyObject *arg1 = NULL, *arg2 = NULL, *query = NULL; 351 | PyArrayObject *ptr = NULL, *neighs = NULL, *seq = NULL, *oarr = NULL, *obj_arr = NULL; 352 | int num_walks = 100, num_steps = 3, seed = 111413, nthread = -1, re = -1; 353 | int n; 354 | 355 | static char *kwlist[] = {"ptr", "neighs", "query", "num_walks", "num_steps", "nthread", "seed", "replacement", NULL}; 356 | if (!(PyArg_ParseTupleAndKeywords(args, kws, "OOO|iiiip", kwlist, &arg1, &arg2, &query, &num_walks, &num_steps, 357 | &nthread, &seed, &re))) { 358 | PyErr_SetString(PyExc_TypeError, "input parsing error."); 359 | return NULL; 360 | } 361 | 362 | /* handle walks (numpy array) */ 363 | ptr = (PyArrayObject *) PyArray_FROM_OTF(arg1, NPY_INT, NPY_ARRAY_IN_ARRAY); 364 | if (!ptr) return NULL; 365 | int *Cptr = PyArray_DATA(ptr); 366 | 367 | neighs = (PyArrayObject *) PyArray_FROM_OTF(arg2, NPY_INT, NPY_ARRAY_IN_ARRAY); 368 | if (!neighs) return NULL; 369 | int *Cneighs = PyArray_DATA(neighs); 370 | 371 | seq = (PyArrayObject *) PyArray_FROM_OTF(query, NPY_INT, NPY_ARRAY_IN_ARRAY | NPY_ARRAY_FORCECAST); 372 | if (!seq) return NULL; 373 | int *Cseq = PyArray_DATA(seq); 374 | 375 | n = (int) PyArray_SIZE(seq); 376 | 377 | npy_intp odims[2] = {n, num_walks * (num_steps + 1)}; 378 | oarr = (PyArrayObject *) PyArray_SimpleNew(2, odims, NPY_INT); 379 | if (oarr == NULL) goto fail; 380 | int *Coarr = (int *) PyArray_DATA(oarr); 381 | 382 | npy_intp obj_dims[2] = {n, 2}; 383 | obj_arr = (PyArrayObject *) PyArray_SimpleNew(2, obj_dims, NPY_OBJECT); 384 | if (obj_arr == NULL) goto fail; 385 | PyArrayObject **Cobj_arr = 
PyArray_DATA(obj_arr); 386 | 387 | if (re > 0) { 388 | // printf("Using no replacement sampling for the 1-hop.\n"); 389 | random_walk_wo(Cptr, Cneighs, Cseq, n, num_walks, num_steps, seed, nthread, Coarr); 390 | } else { 391 | random_walk(Cptr, Cneighs, Cseq, n, num_walks, num_steps, seed, nthread, Coarr); 392 | } 393 | 394 | #pragma omp for 395 | for (int k = 0; k < n; k++) { 396 | rpe_encoder(Coarr, k, num_walks, num_steps, Cobj_arr); 397 | } 398 | 399 | Py_DECREF(ptr); 400 | Py_DECREF(neighs); 401 | Py_DECREF(seq); 402 | return Py_BuildValue("[N,N]", PyArray_Return(oarr), PyArray_Return(obj_arr)); 403 | 404 | fail: 405 | Py_XDECREF(ptr); 406 | Py_XDECREF(neighs); 407 | Py_XDECREF(seq); 408 | PyArray_DiscardWritebackIfCopy(oarr); 409 | PyArray_XDECREF(oarr); 410 | PyArray_DiscardWritebackIfCopy(obj_arr); 411 | PyArray_XDECREF(obj_arr); 412 | return NULL; 413 | } 414 | 415 | static PyObject *np_join(PyObject *self, PyObject *args, PyObject *kws) { 416 | PyObject *arg1 = NULL, *arg2 = NULL, *query = NULL, *seq = NULL, **src; 417 | PyArrayObject *arr = NULL, *iarr = NULL, *oarr = NULL, *xarr = NULL; 418 | int nthread = -1, re = -1; 419 | 420 | static char *kwlist[] = {"walk", "key", "query", "nthread", "return_idx", NULL}; 421 | if (!(PyArg_ParseTupleAndKeywords(args, kws, "OOO|ip", kwlist, &arg1, &arg2, &query, &nthread, &re))) { 422 | PyErr_SetString(PyExc_TypeError, "input parsing error."); 423 | return NULL; 424 | } 425 | 426 | /* handle walks (numpy array) */ 427 | arr = (PyArrayObject *) PyArray_FROM_OTF(arg1, NPY_INT, NPY_ARRAY_IN_ARRAY); 428 | if (!arr) return NULL; 429 | npy_intp *arr_dims = PyArray_DIMS(arr); 430 | int stride; 431 | if (PyArray_NDIM(arr) > 2) { 432 | stride = (int) (arr_dims[1] * arr_dims[2]); 433 | } else { 434 | stride = (int) arr_dims[1]; 435 | } 436 | int *Carr = (int *) PyArray_DATA(arr); 437 | 438 | /* handle keys (a list of tuple) */ 439 | seq = PySequence_Fast(arg2, "argument must be iterable"); 440 | if (!seq) return NULL; 441 | 
if (PySequence_Fast_GET_SIZE(seq) != arr_dims[0]) { 442 | PyErr_SetString(PyExc_TypeError, "dims do not match between walks and keys."); 443 | return NULL; 444 | } 445 | 446 | /* handle queries (numpy array/sequence) */ 447 | iarr = (PyArrayObject *) PyArray_FROM_OTF(query, NPY_INT, NPY_ARRAY_IN_ARRAY | NPY_ARRAY_FORCECAST); 448 | if (!iarr) return NULL; 449 | npy_intp *iarr_dims = PyArray_DIMS(iarr); 450 | int *Ciarr = (int *) PyArray_DATA(iarr); 451 | 452 | if (DEBUG) { 453 | printf("Dims of query: %d, %d\n", (int) iarr_dims[0], (int) iarr_dims[1]); 454 | } 455 | 456 | /* initialize the hashtable */ 457 | dict_item *items = NULL; 458 | int idx = 1; 459 | src = PySequence_Fast_ITEMS(seq); 460 | 461 | /* build two level hash table: 1) reindex main keys 2) hash unique node idx associated with each key */ 462 | for (int i = 0; i < arr_dims[0]; i++) { 463 | /* make initial element */ 464 | dict_item *k = malloc(sizeof(*k)); 465 | // walk starts from each node (main key) 466 | k->key = Carr[i * stride]; 467 | k->sub = NULL; 468 | k->val = i; 469 | HASH_ADD_INT(items, key, k); 470 | 471 | PyObject *item = PySequence_Fast(src[i], "argument must be iterable"); 472 | int item_size; 473 | if (!PyArray_CheckExact(item)) { 474 | item_size = PySequence_Fast_GET_SIZE(item); 475 | 476 | for (int j = 0; j < item_size; j++) { 477 | /* add a sub hash table off this element */ 478 | dict_item *w = malloc(sizeof(*w)); 479 | w->key = (int) PyLong_AsLong(PySequence_Fast_GET_ITEM(item, j)); 480 | w->sub = NULL; 481 | w->val = idx; 482 | HASH_ADD_INT(k->sub, key, w); 483 | idx++; 484 | } 485 | } else { 486 | item_size = PyArray_Size(item); 487 | 488 | for (int j = 0; j < item_size; j++) { 489 | /* add a sub hash table off this element */ 490 | dict_item *w = malloc(sizeof(*w)); 491 | w->key = (*(int *) PyArray_GETPTR1((PyArrayObject *) item, j)); 492 | w->sub = NULL; 493 | w->val = idx; 494 | HASH_ADD_INT(k->sub, key, w); 495 | idx++; 496 | } 497 | } 498 | // must add, to avoid memory 
leakage 499 | Py_DECREF(item); 500 | } 501 | 502 | /* allocate a new return numpy array */ 503 | npy_intp odims[2] = {2, iarr_dims[0] * 2 * stride}; 504 | oarr = (PyArrayObject *) PyArray_SimpleNew(2, odims, NPY_INT); 505 | if (oarr == NULL) goto fail; 506 | int *Coarr = (int *) PyArray_DATA(oarr); 507 | 508 | if (DEBUG) { 509 | printf("Dims of output: %d, %d\n", (int) odims[0], (int) odims[1]); 510 | } 511 | 512 | xarr = (PyArrayObject *) PyArray_SimpleNew(2, iarr_dims, NPY_INT); 513 | if (xarr == NULL) goto fail; 514 | int *Cxarr = (int *) PyArray_DATA(xarr); 515 | 516 | if (nthread > 0) { 517 | omp_set_num_threads(nthread); 518 | } 519 | 520 | #pragma omp parallel for 521 | for (int x = 0; x < iarr_dims[0]; x++) { 522 | int qid = 2 * x; 523 | int key1 = Ciarr[qid], key2 = Ciarr[qid + 1]; 524 | Cxarr[qid] = find_key_item(items, key1), Cxarr[qid+1] = find_key_item(items, key2); 525 | for (int y = 0; y < 2 * stride; y += 2) { 526 | Coarr[qid * stride + y] = find_idx(items, key1, Carr[Cxarr[qid] * stride + y / 2]); 527 | Coarr[qid * stride + y + 1] = find_idx(items, key2, Carr[Cxarr[qid] * stride + y / 2]); 528 | Coarr[odims[1] + qid * stride + y] = find_idx(items, key1, Carr[Cxarr[qid+1] * stride + y / 2]); 529 | Coarr[odims[1] + qid * stride + y + 1] = find_idx(items, key2, Carr[Cxarr[qid+1] * stride + y / 2]); 530 | } 531 | } 532 | 533 | Py_DECREF(arr); 534 | Py_DECREF(iarr); 535 | Py_DECREF(seq); 536 | delete_all(items); 537 | if (re>0){ 538 | return Py_BuildValue("[N,N]", PyArray_Return(oarr), PyArray_Return(xarr)); 539 | }else{ 540 | return PyArray_Return(oarr); 541 | } 542 | 543 | fail: 544 | Py_XDECREF(arr); 545 | Py_XDECREF(iarr); 546 | Py_XDECREF(seq); 547 | delete_all(items); 548 | PyArray_DiscardWritebackIfCopy(oarr); 549 | PyArray_XDECREF(oarr); 550 | PyArray_DiscardWritebackIfCopy(xarr); 551 | PyArray_XDECREF(xarr); 552 | return NULL; 553 | } 554 | 555 | static PyMethodDef GAccMethods[] = { 556 | {"add", adds, METH_VARARGS, "Add ops."}, 557 | {"run", 
exe, METH_VARARGS, "Execute a shell command."}, 558 | {"sjoin", (PyCFunction) np_join, METH_VARARGS | METH_KEYWORDS, 559 | "RPE (subgraph) join op with a list of pairs (numpy, openmp)."}, 560 | {"run_walk", (PyCFunction) np_walk, METH_VARARGS | METH_KEYWORDS, 561 | "Random walks with RPE encoding (numpy, openmp)."}, 562 | {"run_sample", (PyCFunction) np_sample, METH_VARARGS | METH_KEYWORDS, 563 | "Random sampling (numpy, openmp)."}, 564 | {NULL, NULL, 0, NULL} 565 | }; 566 | 567 | static char gacc_doc[] = "C extension for SUREL framework."; 568 | 569 | static struct PyModuleDef gaccmodule = { 570 | PyModuleDef_HEAD_INIT, 571 | "surel_gacc", /* name of module */ 572 | gacc_doc, /* module documentation, may be NULL */ 573 | -1, /* size of per-interpreter state of the module, 574 | or -1 if the module keeps state in global variables. */ 575 | GAccMethods 576 | }; 577 | 578 | PyMODINIT_FUNC PyInit_surel_gacc(void) { 579 | import_array(); 580 | return PyModule_Create(&gaccmodule); 581 | } 582 | -------------------------------------------------------------------------------- /subg_acc/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup, Extension 2 | import numpy 3 | 4 | module1 = Extension('surel_gacc', 5 | sources = ['graph_acc.c'], 6 | extra_compile_args=['-fopenmp'], 7 | extra_link_args=['-lgomp'], 8 | include_dirs=[numpy.get_include()]) 9 | 10 | setup (name = 'SUREL_GAcc', 11 | version = '1.1', 12 | description = 'This is a package for accelerated graph operations in SUREL framework.', 13 | ext_modules = [module1]) 14 | -------------------------------------------------------------------------------- /subg_acc/uthash.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2003-2022, Troy D. Hanson https://troydhanson.github.io/uthash/ 3 | All rights reserved. 
4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS 12 | IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED 13 | TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 14 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 15 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 16 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 17 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 18 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 19 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 20 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 21 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 22 | */ 23 | 24 | #ifndef UTHASH_H 25 | #define UTHASH_H 26 | 27 | #define UTHASH_VERSION 2.3.0 28 | 29 | #include /* memcmp, memset, strlen */ 30 | #include /* ptrdiff_t */ 31 | #include /* exit */ 32 | 33 | #if defined(HASH_DEFINE_OWN_STDINT) && HASH_DEFINE_OWN_STDINT 34 | /* This codepath is provided for backward compatibility, but I plan to remove it. */ 35 | #warning "HASH_DEFINE_OWN_STDINT is deprecated; please use HASH_NO_STDINT instead" 36 | typedef unsigned int uint32_t; 37 | typedef unsigned char uint8_t; 38 | #elif defined(HASH_NO_STDINT) && HASH_NO_STDINT 39 | #else 40 | #include /* uint8_t, uint32_t */ 41 | #endif 42 | 43 | /* These macros use decltype or the earlier __typeof GNU extension. 
44 | As decltype is only available in newer compilers (VS2010 or gcc 4.3+ 45 | when compiling c++ source) this code uses whatever method is needed 46 | or, for VS2008 where neither is available, uses casting workarounds. */ 47 | #if !defined(DECLTYPE) && !defined(NO_DECLTYPE) 48 | #if defined(_MSC_VER) /* MS compiler */ 49 | #if _MSC_VER >= 1600 && defined(__cplusplus) /* VS2010 or newer in C++ mode */ 50 | #define DECLTYPE(x) (decltype(x)) 51 | #else /* VS2008 or older (or VS2010 in C mode) */ 52 | #define NO_DECLTYPE 53 | #endif 54 | #elif defined(__MCST__) /* Elbrus C Compiler */ 55 | #define DECLTYPE(x) (__typeof(x)) 56 | #elif defined(__BORLANDC__) || defined(__ICCARM__) || defined(__LCC__) || defined(__WATCOMC__) 57 | #define NO_DECLTYPE 58 | #else /* GNU, Sun and other compilers */ 59 | #define DECLTYPE(x) (__typeof(x)) 60 | #endif 61 | #endif 62 | 63 | #ifdef NO_DECLTYPE 64 | #define DECLTYPE(x) 65 | #define DECLTYPE_ASSIGN(dst,src) \ 66 | do { \ 67 | char **_da_dst = (char**)(&(dst)); \ 68 | *_da_dst = (char*)(src); \ 69 | } while (0) 70 | #else 71 | #define DECLTYPE_ASSIGN(dst,src) \ 72 | do { \ 73 | (dst) = DECLTYPE(dst)(src); \ 74 | } while (0) 75 | #endif 76 | 77 | #ifndef uthash_malloc 78 | #define uthash_malloc(sz) malloc(sz) /* malloc fcn */ 79 | #endif 80 | #ifndef uthash_free 81 | #define uthash_free(ptr,sz) free(ptr) /* free fcn */ 82 | #endif 83 | #ifndef uthash_bzero 84 | #define uthash_bzero(a,n) memset(a,'\0',n) 85 | #endif 86 | #ifndef uthash_strlen 87 | #define uthash_strlen(s) strlen(s) 88 | #endif 89 | 90 | #ifndef HASH_FUNCTION 91 | #define HASH_FUNCTION(keyptr,keylen,hashv) HASH_JEN(keyptr, keylen, hashv) 92 | #endif 93 | 94 | #ifndef HASH_KEYCMP 95 | #define HASH_KEYCMP(a,b,n) memcmp(a,b,n) 96 | #endif 97 | 98 | #ifndef uthash_noexpand_fyi 99 | #define uthash_noexpand_fyi(tbl) /* can be defined to log noexpand */ 100 | #endif 101 | #ifndef uthash_expand_fyi 102 | #define uthash_expand_fyi(tbl) /* can be defined to log expands */ 103 | 
#endif 104 | 105 | #ifndef HASH_NONFATAL_OOM 106 | #define HASH_NONFATAL_OOM 0 107 | #endif 108 | 109 | #if HASH_NONFATAL_OOM 110 | /* malloc failures can be recovered from */ 111 | 112 | #ifndef uthash_nonfatal_oom 113 | #define uthash_nonfatal_oom(obj) do {} while (0) /* non-fatal OOM error */ 114 | #endif 115 | 116 | #define HASH_RECORD_OOM(oomed) do { (oomed) = 1; } while (0) 117 | #define IF_HASH_NONFATAL_OOM(x) x 118 | 119 | #else 120 | /* malloc failures result in lost memory, hash tables are unusable */ 121 | 122 | #ifndef uthash_fatal 123 | #define uthash_fatal(msg) exit(-1) /* fatal OOM error */ 124 | #endif 125 | 126 | #define HASH_RECORD_OOM(oomed) uthash_fatal("out of memory") 127 | #define IF_HASH_NONFATAL_OOM(x) 128 | 129 | #endif 130 | 131 | /* initial number of buckets */ 132 | #define HASH_INITIAL_NUM_BUCKETS 32U /* initial number of buckets */ 133 | #define HASH_INITIAL_NUM_BUCKETS_LOG2 5U /* lg2 of initial number of buckets */ 134 | #define HASH_BKT_CAPACITY_THRESH 10U /* expand when bucket count reaches */ 135 | 136 | /* calculate the element whose hash handle address is hhp */ 137 | #define ELMT_FROM_HH(tbl,hhp) ((void*)(((char*)(hhp)) - ((tbl)->hho))) 138 | /* calculate the hash handle from element address elp */ 139 | #define HH_FROM_ELMT(tbl,elp) ((UT_hash_handle*)(void*)(((char*)(elp)) + ((tbl)->hho))) 140 | 141 | #define HASH_ROLLBACK_BKT(hh, head, itemptrhh) \ 142 | do { \ 143 | struct UT_hash_handle *_hd_hh_item = (itemptrhh); \ 144 | unsigned _hd_bkt; \ 145 | HASH_TO_BKT(_hd_hh_item->hashv, (head)->hh.tbl->num_buckets, _hd_bkt); \ 146 | (head)->hh.tbl->buckets[_hd_bkt].count++; \ 147 | _hd_hh_item->hh_next = NULL; \ 148 | _hd_hh_item->hh_prev = NULL; \ 149 | } while (0) 150 | 151 | #define HASH_VALUE(keyptr,keylen,hashv) \ 152 | do { \ 153 | HASH_FUNCTION(keyptr, keylen, hashv); \ 154 | } while (0) 155 | 156 | #define HASH_FIND_BYHASHVALUE(hh,head,keyptr,keylen,hashval,out) \ 157 | do { \ 158 | (out) = NULL; \ 159 | if (head) { \ 160 | 
unsigned _hf_bkt; \ 161 | HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _hf_bkt); \ 162 | if (HASH_BLOOM_TEST((head)->hh.tbl, hashval) != 0) { \ 163 | HASH_FIND_IN_BKT((head)->hh.tbl, hh, (head)->hh.tbl->buckets[ _hf_bkt ], keyptr, keylen, hashval, out); \ 164 | } \ 165 | } \ 166 | } while (0) 167 | 168 | #define HASH_FIND(hh,head,keyptr,keylen,out) \ 169 | do { \ 170 | (out) = NULL; \ 171 | if (head) { \ 172 | unsigned _hf_hashv; \ 173 | HASH_VALUE(keyptr, keylen, _hf_hashv); \ 174 | HASH_FIND_BYHASHVALUE(hh, head, keyptr, keylen, _hf_hashv, out); \ 175 | } \ 176 | } while (0) 177 | 178 | #ifdef HASH_BLOOM 179 | #define HASH_BLOOM_BITLEN (1UL << HASH_BLOOM) 180 | #define HASH_BLOOM_BYTELEN (HASH_BLOOM_BITLEN/8UL) + (((HASH_BLOOM_BITLEN%8UL)!=0UL) ? 1UL : 0UL) 181 | #define HASH_BLOOM_MAKE(tbl,oomed) \ 182 | do { \ 183 | (tbl)->bloom_nbits = HASH_BLOOM; \ 184 | (tbl)->bloom_bv = (uint8_t*)uthash_malloc(HASH_BLOOM_BYTELEN); \ 185 | if (!(tbl)->bloom_bv) { \ 186 | HASH_RECORD_OOM(oomed); \ 187 | } else { \ 188 | uthash_bzero((tbl)->bloom_bv, HASH_BLOOM_BYTELEN); \ 189 | (tbl)->bloom_sig = HASH_BLOOM_SIGNATURE; \ 190 | } \ 191 | } while (0) 192 | 193 | #define HASH_BLOOM_FREE(tbl) \ 194 | do { \ 195 | uthash_free((tbl)->bloom_bv, HASH_BLOOM_BYTELEN); \ 196 | } while (0) 197 | 198 | #define HASH_BLOOM_BITSET(bv,idx) (bv[(idx)/8U] |= (1U << ((idx)%8U))) 199 | #define HASH_BLOOM_BITTEST(bv,idx) (bv[(idx)/8U] & (1U << ((idx)%8U))) 200 | 201 | #define HASH_BLOOM_ADD(tbl,hashv) \ 202 | HASH_BLOOM_BITSET((tbl)->bloom_bv, ((hashv) & (uint32_t)((1UL << (tbl)->bloom_nbits) - 1U))) 203 | 204 | #define HASH_BLOOM_TEST(tbl,hashv) \ 205 | HASH_BLOOM_BITTEST((tbl)->bloom_bv, ((hashv) & (uint32_t)((1UL << (tbl)->bloom_nbits) - 1U))) 206 | 207 | #else 208 | #define HASH_BLOOM_MAKE(tbl,oomed) 209 | #define HASH_BLOOM_FREE(tbl) 210 | #define HASH_BLOOM_ADD(tbl,hashv) 211 | #define HASH_BLOOM_TEST(tbl,hashv) (1) 212 | #define HASH_BLOOM_BYTELEN 0U 213 | #endif 214 | 215 | #define 
HASH_MAKE_TABLE(hh,head,oomed) \ 216 | do { \ 217 | (head)->hh.tbl = (UT_hash_table*)uthash_malloc(sizeof(UT_hash_table)); \ 218 | if (!(head)->hh.tbl) { \ 219 | HASH_RECORD_OOM(oomed); \ 220 | } else { \ 221 | uthash_bzero((head)->hh.tbl, sizeof(UT_hash_table)); \ 222 | (head)->hh.tbl->tail = &((head)->hh); \ 223 | (head)->hh.tbl->num_buckets = HASH_INITIAL_NUM_BUCKETS; \ 224 | (head)->hh.tbl->log2_num_buckets = HASH_INITIAL_NUM_BUCKETS_LOG2; \ 225 | (head)->hh.tbl->hho = (char*)(&(head)->hh) - (char*)(head); \ 226 | (head)->hh.tbl->buckets = (UT_hash_bucket*)uthash_malloc( \ 227 | HASH_INITIAL_NUM_BUCKETS * sizeof(struct UT_hash_bucket)); \ 228 | (head)->hh.tbl->signature = HASH_SIGNATURE; \ 229 | if (!(head)->hh.tbl->buckets) { \ 230 | HASH_RECORD_OOM(oomed); \ 231 | uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \ 232 | } else { \ 233 | uthash_bzero((head)->hh.tbl->buckets, \ 234 | HASH_INITIAL_NUM_BUCKETS * sizeof(struct UT_hash_bucket)); \ 235 | HASH_BLOOM_MAKE((head)->hh.tbl, oomed); \ 236 | IF_HASH_NONFATAL_OOM( \ 237 | if (oomed) { \ 238 | uthash_free((head)->hh.tbl->buckets, \ 239 | HASH_INITIAL_NUM_BUCKETS*sizeof(struct UT_hash_bucket)); \ 240 | uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \ 241 | } \ 242 | ) \ 243 | } \ 244 | } \ 245 | } while (0) 246 | 247 | #define HASH_REPLACE_BYHASHVALUE_INORDER(hh,head,fieldname,keylen_in,hashval,add,replaced,cmpfcn) \ 248 | do { \ 249 | (replaced) = NULL; \ 250 | HASH_FIND_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, replaced); \ 251 | if (replaced) { \ 252 | HASH_DELETE(hh, head, replaced); \ 253 | } \ 254 | HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, &((add)->fieldname), keylen_in, hashval, add, cmpfcn); \ 255 | } while (0) 256 | 257 | #define HASH_REPLACE_BYHASHVALUE(hh,head,fieldname,keylen_in,hashval,add,replaced) \ 258 | do { \ 259 | (replaced) = NULL; \ 260 | HASH_FIND_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, replaced); \ 261 | if (replaced) { \ 262 | 
HASH_DELETE(hh, head, replaced); \ 263 | } \ 264 | HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, add); \ 265 | } while (0) 266 | 267 | #define HASH_REPLACE(hh,head,fieldname,keylen_in,add,replaced) \ 268 | do { \ 269 | unsigned _hr_hashv; \ 270 | HASH_VALUE(&((add)->fieldname), keylen_in, _hr_hashv); \ 271 | HASH_REPLACE_BYHASHVALUE(hh, head, fieldname, keylen_in, _hr_hashv, add, replaced); \ 272 | } while (0) 273 | 274 | #define HASH_REPLACE_INORDER(hh,head,fieldname,keylen_in,add,replaced,cmpfcn) \ 275 | do { \ 276 | unsigned _hr_hashv; \ 277 | HASH_VALUE(&((add)->fieldname), keylen_in, _hr_hashv); \ 278 | HASH_REPLACE_BYHASHVALUE_INORDER(hh, head, fieldname, keylen_in, _hr_hashv, add, replaced, cmpfcn); \ 279 | } while (0) 280 | 281 | #define HASH_APPEND_LIST(hh, head, add) \ 282 | do { \ 283 | (add)->hh.next = NULL; \ 284 | (add)->hh.prev = ELMT_FROM_HH((head)->hh.tbl, (head)->hh.tbl->tail); \ 285 | (head)->hh.tbl->tail->next = (add); \ 286 | (head)->hh.tbl->tail = &((add)->hh); \ 287 | } while (0) 288 | 289 | #define HASH_AKBI_INNER_LOOP(hh,head,add,cmpfcn) \ 290 | do { \ 291 | do { \ 292 | if (cmpfcn(DECLTYPE(head)(_hs_iter), add) > 0) { \ 293 | break; \ 294 | } \ 295 | } while ((_hs_iter = HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->next)); \ 296 | } while (0) 297 | 298 | #ifdef NO_DECLTYPE 299 | #undef HASH_AKBI_INNER_LOOP 300 | #define HASH_AKBI_INNER_LOOP(hh,head,add,cmpfcn) \ 301 | do { \ 302 | char *_hs_saved_head = (char*)(head); \ 303 | do { \ 304 | DECLTYPE_ASSIGN(head, _hs_iter); \ 305 | if (cmpfcn(head, add) > 0) { \ 306 | DECLTYPE_ASSIGN(head, _hs_saved_head); \ 307 | break; \ 308 | } \ 309 | DECLTYPE_ASSIGN(head, _hs_saved_head); \ 310 | } while ((_hs_iter = HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->next)); \ 311 | } while (0) 312 | #endif 313 | 314 | #if HASH_NONFATAL_OOM 315 | 316 | #define HASH_ADD_TO_TABLE(hh,head,keyptr,keylen_in,hashval,add,oomed) \ 317 | do { \ 318 | if (!(oomed)) { \ 319 | unsigned _ha_bkt; \ 
320 | (head)->hh.tbl->num_items++; \ 321 | HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _ha_bkt); \ 322 | HASH_ADD_TO_BKT((head)->hh.tbl->buckets[_ha_bkt], hh, &(add)->hh, oomed); \ 323 | if (oomed) { \ 324 | HASH_ROLLBACK_BKT(hh, head, &(add)->hh); \ 325 | HASH_DELETE_HH(hh, head, &(add)->hh); \ 326 | (add)->hh.tbl = NULL; \ 327 | uthash_nonfatal_oom(add); \ 328 | } else { \ 329 | HASH_BLOOM_ADD((head)->hh.tbl, hashval); \ 330 | HASH_EMIT_KEY(hh, head, keyptr, keylen_in); \ 331 | } \ 332 | } else { \ 333 | (add)->hh.tbl = NULL; \ 334 | uthash_nonfatal_oom(add); \ 335 | } \ 336 | } while (0) 337 | 338 | #else 339 | 340 | #define HASH_ADD_TO_TABLE(hh,head,keyptr,keylen_in,hashval,add,oomed) \ 341 | do { \ 342 | unsigned _ha_bkt; \ 343 | (head)->hh.tbl->num_items++; \ 344 | HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _ha_bkt); \ 345 | HASH_ADD_TO_BKT((head)->hh.tbl->buckets[_ha_bkt], hh, &(add)->hh, oomed); \ 346 | HASH_BLOOM_ADD((head)->hh.tbl, hashval); \ 347 | HASH_EMIT_KEY(hh, head, keyptr, keylen_in); \ 348 | } while (0) 349 | 350 | #endif 351 | 352 | 353 | #define HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh,head,keyptr,keylen_in,hashval,add,cmpfcn) \ 354 | do { \ 355 | IF_HASH_NONFATAL_OOM( int _ha_oomed = 0; ) \ 356 | (add)->hh.hashv = (hashval); \ 357 | (add)->hh.key = (char*) (keyptr); \ 358 | (add)->hh.keylen = (unsigned) (keylen_in); \ 359 | if (!(head)) { \ 360 | (add)->hh.next = NULL; \ 361 | (add)->hh.prev = NULL; \ 362 | HASH_MAKE_TABLE(hh, add, _ha_oomed); \ 363 | IF_HASH_NONFATAL_OOM( if (!_ha_oomed) { ) \ 364 | (head) = (add); \ 365 | IF_HASH_NONFATAL_OOM( } ) \ 366 | } else { \ 367 | void *_hs_iter = (head); \ 368 | (add)->hh.tbl = (head)->hh.tbl; \ 369 | HASH_AKBI_INNER_LOOP(hh, head, add, cmpfcn); \ 370 | if (_hs_iter) { \ 371 | (add)->hh.next = _hs_iter; \ 372 | if (((add)->hh.prev = HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->prev)) { \ 373 | HH_FROM_ELMT((head)->hh.tbl, (add)->hh.prev)->next = (add); \ 374 | } else { \ 375 | (head) = (add); 
\ 376 | } \ 377 | HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->prev = (add); \ 378 | } else { \ 379 | HASH_APPEND_LIST(hh, head, add); \ 380 | } \ 381 | } \ 382 | HASH_ADD_TO_TABLE(hh, head, keyptr, keylen_in, hashval, add, _ha_oomed); \ 383 | HASH_FSCK(hh, head, "HASH_ADD_KEYPTR_BYHASHVALUE_INORDER"); \ 384 | } while (0) 385 | 386 | #define HASH_ADD_KEYPTR_INORDER(hh,head,keyptr,keylen_in,add,cmpfcn) \ 387 | do { \ 388 | unsigned _hs_hashv; \ 389 | HASH_VALUE(keyptr, keylen_in, _hs_hashv); \ 390 | HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, keyptr, keylen_in, _hs_hashv, add, cmpfcn); \ 391 | } while (0) 392 | 393 | #define HASH_ADD_BYHASHVALUE_INORDER(hh,head,fieldname,keylen_in,hashval,add,cmpfcn) \ 394 | HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, &((add)->fieldname), keylen_in, hashval, add, cmpfcn) 395 | 396 | #define HASH_ADD_INORDER(hh,head,fieldname,keylen_in,add,cmpfcn) \ 397 | HASH_ADD_KEYPTR_INORDER(hh, head, &((add)->fieldname), keylen_in, add, cmpfcn) 398 | 399 | #define HASH_ADD_KEYPTR_BYHASHVALUE(hh,head,keyptr,keylen_in,hashval,add) \ 400 | do { \ 401 | IF_HASH_NONFATAL_OOM( int _ha_oomed = 0; ) \ 402 | (add)->hh.hashv = (hashval); \ 403 | (add)->hh.key = (const void*) (keyptr); \ 404 | (add)->hh.keylen = (unsigned) (keylen_in); \ 405 | if (!(head)) { \ 406 | (add)->hh.next = NULL; \ 407 | (add)->hh.prev = NULL; \ 408 | HASH_MAKE_TABLE(hh, add, _ha_oomed); \ 409 | IF_HASH_NONFATAL_OOM( if (!_ha_oomed) { ) \ 410 | (head) = (add); \ 411 | IF_HASH_NONFATAL_OOM( } ) \ 412 | } else { \ 413 | (add)->hh.tbl = (head)->hh.tbl; \ 414 | HASH_APPEND_LIST(hh, head, add); \ 415 | } \ 416 | HASH_ADD_TO_TABLE(hh, head, keyptr, keylen_in, hashval, add, _ha_oomed); \ 417 | HASH_FSCK(hh, head, "HASH_ADD_KEYPTR_BYHASHVALUE"); \ 418 | } while (0) 419 | 420 | #define HASH_ADD_KEYPTR(hh,head,keyptr,keylen_in,add) \ 421 | do { \ 422 | unsigned _ha_hashv; \ 423 | HASH_VALUE(keyptr, keylen_in, _ha_hashv); \ 424 | HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, keyptr, keylen_in, 
_ha_hashv, add); \ 425 | } while (0) 426 | 427 | #define HASH_ADD_BYHASHVALUE(hh,head,fieldname,keylen_in,hashval,add) \ 428 | HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, add) 429 | 430 | #define HASH_ADD(hh,head,fieldname,keylen_in,add) \ 431 | HASH_ADD_KEYPTR(hh, head, &((add)->fieldname), keylen_in, add) 432 | 433 | #define HASH_TO_BKT(hashv,num_bkts,bkt) \ 434 | do { \ 435 | bkt = ((hashv) & ((num_bkts) - 1U)); \ 436 | } while (0) 437 | 438 | /* delete "delptr" from the hash table. 439 | * "the usual" patch-up process for the app-order doubly-linked-list. 440 | * The use of _hd_hh_del below deserves special explanation. 441 | * These used to be expressed using (delptr) but that led to a bug 442 | * if someone used the same symbol for the head and deletee, like 443 | * HASH_DELETE(hh,users,users); 444 | * We want that to work, but by changing the head (users) below 445 | * we were forfeiting our ability to further refer to the deletee (users) 446 | * in the patch-up process. Solution: use scratch space to 447 | * copy the deletee pointer, then the latter references are via that 448 | * scratch pointer rather than through the repointed (users) symbol. 
449 | */ 450 | #define HASH_DELETE(hh,head,delptr) \ 451 | HASH_DELETE_HH(hh, head, &(delptr)->hh) 452 | 453 | #define HASH_DELETE_HH(hh,head,delptrhh) \ 454 | do { \ 455 | const struct UT_hash_handle *_hd_hh_del = (delptrhh); \ 456 | if ((_hd_hh_del->prev == NULL) && (_hd_hh_del->next == NULL)) { \ 457 | HASH_BLOOM_FREE((head)->hh.tbl); \ 458 | uthash_free((head)->hh.tbl->buckets, \ 459 | (head)->hh.tbl->num_buckets * sizeof(struct UT_hash_bucket)); \ 460 | uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \ 461 | (head) = NULL; \ 462 | } else { \ 463 | unsigned _hd_bkt; \ 464 | if (_hd_hh_del == (head)->hh.tbl->tail) { \ 465 | (head)->hh.tbl->tail = HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->prev); \ 466 | } \ 467 | if (_hd_hh_del->prev != NULL) { \ 468 | HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->prev)->next = _hd_hh_del->next; \ 469 | } else { \ 470 | DECLTYPE_ASSIGN(head, _hd_hh_del->next); \ 471 | } \ 472 | if (_hd_hh_del->next != NULL) { \ 473 | HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->next)->prev = _hd_hh_del->prev; \ 474 | } \ 475 | HASH_TO_BKT(_hd_hh_del->hashv, (head)->hh.tbl->num_buckets, _hd_bkt); \ 476 | HASH_DEL_IN_BKT((head)->hh.tbl->buckets[_hd_bkt], _hd_hh_del); \ 477 | (head)->hh.tbl->num_items--; \ 478 | } \ 479 | HASH_FSCK(hh, head, "HASH_DELETE_HH"); \ 480 | } while (0) 481 | 482 | /* convenience forms of HASH_FIND/HASH_ADD/HASH_DEL */ 483 | #define HASH_FIND_STR(head,findstr,out) \ 484 | do { \ 485 | unsigned _uthash_hfstr_keylen = (unsigned)uthash_strlen(findstr); \ 486 | HASH_FIND(hh, head, findstr, _uthash_hfstr_keylen, out); \ 487 | } while (0) 488 | #define HASH_ADD_STR(head,strfield,add) \ 489 | do { \ 490 | unsigned _uthash_hastr_keylen = (unsigned)uthash_strlen((add)->strfield); \ 491 | HASH_ADD(hh, head, strfield[0], _uthash_hastr_keylen, add); \ 492 | } while (0) 493 | #define HASH_REPLACE_STR(head,strfield,add,replaced) \ 494 | do { \ 495 | unsigned _uthash_hrstr_keylen = (unsigned)uthash_strlen((add)->strfield); \ 496 | 
HASH_REPLACE(hh, head, strfield[0], _uthash_hrstr_keylen, add, replaced); \ 497 | } while (0) 498 | #define HASH_FIND_INT(head,findint,out) \ 499 | HASH_FIND(hh,head,findint,sizeof(int),out) 500 | #define HASH_ADD_INT(head,intfield,add) \ 501 | HASH_ADD(hh,head,intfield,sizeof(int),add) 502 | #define HASH_REPLACE_INT(head,intfield,add,replaced) \ 503 | HASH_REPLACE(hh,head,intfield,sizeof(int),add,replaced) 504 | #define HASH_FIND_PTR(head,findptr,out) \ 505 | HASH_FIND(hh,head,findptr,sizeof(void *),out) 506 | #define HASH_ADD_PTR(head,ptrfield,add) \ 507 | HASH_ADD(hh,head,ptrfield,sizeof(void *),add) 508 | #define HASH_REPLACE_PTR(head,ptrfield,add,replaced) \ 509 | HASH_REPLACE(hh,head,ptrfield,sizeof(void *),add,replaced) 510 | #define HASH_DEL(head,delptr) \ 511 | HASH_DELETE(hh,head,delptr) 512 | 513 | /* HASH_FSCK checks hash integrity on every add/delete when HASH_DEBUG is defined. 514 | * This is for uthash developer only; it compiles away if HASH_DEBUG isn't defined. 515 | */ 516 | #ifdef HASH_DEBUG 517 | #include /* fprintf, stderr */ 518 | #define HASH_OOPS(...) 
do { fprintf(stderr, __VA_ARGS__); exit(-1); } while (0) 519 | #define HASH_FSCK(hh,head,where) \ 520 | do { \ 521 | struct UT_hash_handle *_thh; \ 522 | if (head) { \ 523 | unsigned _bkt_i; \ 524 | unsigned _count = 0; \ 525 | char *_prev; \ 526 | for (_bkt_i = 0; _bkt_i < (head)->hh.tbl->num_buckets; ++_bkt_i) { \ 527 | unsigned _bkt_count = 0; \ 528 | _thh = (head)->hh.tbl->buckets[_bkt_i].hh_head; \ 529 | _prev = NULL; \ 530 | while (_thh) { \ 531 | if (_prev != (char*)(_thh->hh_prev)) { \ 532 | HASH_OOPS("%s: invalid hh_prev %p, actual %p\n", \ 533 | (where), (void*)_thh->hh_prev, (void*)_prev); \ 534 | } \ 535 | _bkt_count++; \ 536 | _prev = (char*)(_thh); \ 537 | _thh = _thh->hh_next; \ 538 | } \ 539 | _count += _bkt_count; \ 540 | if ((head)->hh.tbl->buckets[_bkt_i].count != _bkt_count) { \ 541 | HASH_OOPS("%s: invalid bucket count %u, actual %u\n", \ 542 | (where), (head)->hh.tbl->buckets[_bkt_i].count, _bkt_count); \ 543 | } \ 544 | } \ 545 | if (_count != (head)->hh.tbl->num_items) { \ 546 | HASH_OOPS("%s: invalid hh item count %u, actual %u\n", \ 547 | (where), (head)->hh.tbl->num_items, _count); \ 548 | } \ 549 | _count = 0; \ 550 | _prev = NULL; \ 551 | _thh = &(head)->hh; \ 552 | while (_thh) { \ 553 | _count++; \ 554 | if (_prev != (char*)_thh->prev) { \ 555 | HASH_OOPS("%s: invalid prev %p, actual %p\n", \ 556 | (where), (void*)_thh->prev, (void*)_prev); \ 557 | } \ 558 | _prev = (char*)ELMT_FROM_HH((head)->hh.tbl, _thh); \ 559 | _thh = (_thh->next ? HH_FROM_ELMT((head)->hh.tbl, _thh->next) : NULL); \ 560 | } \ 561 | if (_count != (head)->hh.tbl->num_items) { \ 562 | HASH_OOPS("%s: invalid app item count %u, actual %u\n", \ 563 | (where), (head)->hh.tbl->num_items, _count); \ 564 | } \ 565 | } \ 566 | } while (0) 567 | #else 568 | #define HASH_FSCK(hh,head,where) 569 | #endif 570 | 571 | /* When compiled with -DHASH_EMIT_KEYS, length-prefixed keys are emitted to 572 | * the descriptor to which this macro is defined for tuning the hash function. 
 * The app can #include <unistd.h> to get the prototype for write(2). */
#ifdef HASH_EMIT_KEYS
/* Write every key to the file descriptor HASH_EMIT_KEYS as a length-prefixed
 * record: the key length (an unsigned) followed by the raw key bytes. */
#define HASH_EMIT_KEY(hh,head,keyptr,fieldlen) \
do { \
  unsigned _klen = fieldlen; \
  write(HASH_EMIT_KEYS, &_klen, sizeof(_klen)); \
  write(HASH_EMIT_KEYS, keyptr, (unsigned long)fieldlen); \
} while (0)
#else
/* Key emission disabled: the macro expands to nothing. */
#define HASH_EMIT_KEY(hh,head,keyptr,fieldlen)
#endif

/* Every HASH_xxx(key,keylen,hashv) macro below hashes keylen bytes starting
 * at key and stores the result in the lvalue hashv. */

/* The Bernstein hash function, used in Perl prior to v5.6.
 * Per byte: hashv = ((hashv << 5) + hashv) + byte, i.e. hashv*33 + byte. */
#define HASH_BER(key,keylen,hashv) \
do { \
  unsigned _hb_keylen = (unsigned)keylen; \
  const unsigned char *_hb_key = (const unsigned char*)(key); \
  (hashv) = 0; \
  while (_hb_keylen-- != 0U) { \
    (hashv) = (((hashv) << 5) + (hashv)) + *_hb_key++; \
  } \
} while (0)


/* SAX/FNV/OAT/JEN hash functions are macro variants of those listed at
 * http://eternallyconfuzzled.com/tuts/algorithms/jsw_tut_hashing.aspx
 * (archive link: https://archive.is/Ivcan )
 */
/* Shift-add-xor: per byte, hashv ^= (hashv << 5) + (hashv >> 2) + byte. */
#define HASH_SAX(key,keylen,hashv) \
do { \
  unsigned _sx_i; \
  const unsigned char *_hs_key = (const unsigned char*)(key); \
  hashv = 0; \
  for (_sx_i=0; _sx_i < keylen; _sx_i++) { \
    hashv ^= (hashv << 5) + (hashv >> 2) + _hs_key[_sx_i]; \
  } \
} while (0)
/* FNV-1a variation: start from offset basis 2166136261, then per byte
 * xor the byte in and multiply by the FNV prime 16777619. */
#define HASH_FNV(key,keylen,hashv) \
do { \
  unsigned _fn_i; \
  const unsigned char *_hf_key = (const unsigned char*)(key); \
  (hashv) = 2166136261U; \
  for (_fn_i=0; _fn_i < keylen; _fn_i++) { \
    hashv = hashv ^ _hf_key[_fn_i]; \
    hashv = hashv * 16777619U; \
  } \
} while (0)

/* Jenkins "one-at-a-time": per-byte add / shift-add / xor-shift, followed by
 * a final three-step mix. */
#define HASH_OAT(key,keylen,hashv) \
do { \
  unsigned _ho_i; \
  const unsigned char *_ho_key=(const unsigned char*)(key); \
  hashv = 0; \
  for(_ho_i=0; _ho_i < keylen; _ho_i++) { \
      hashv += _ho_key[_ho_i]; \
      hashv += (hashv << 10); \
      hashv ^= (hashv >> 6); \
  } \
  hashv += (hashv << 3); \
  hashv ^= (hashv >> 11); \
  hashv += (hashv << 15); \
} while (0)

/* Mixing step used by HASH_JEN: scrambles the three accumulators a, b, c
 * in place (all three arguments must be modifiable lvalues). */
#define HASH_JEN_MIX(a,b,c) \
do { \
  a -= b; a -= c; a ^= ( c >> 13 ); \
  b -= c; b -= a; b ^= ( a << 8 ); \
  c -= a; c -= b; c ^= ( b >> 13 ); \
  a -= b; a -= c; a ^= ( c >> 12 ); \
  b -= c; b -= a; b ^= ( a << 16 ); \
  c -= a; c -= b; c ^= ( b >> 5 ); \
  a -= b; a -= c; a ^= ( c >> 3 ); \
  b -= c; b -= a; b ^= ( a << 10 ); \
  c -= a; c -= b; c ^= ( b >> 15 ); \
} while (0)

/* Jenkins hash: consumes the key 12 bytes at a time into three accumulators
 * (little-endian byte assembly), mixing after each block. The 0-11 byte tail
 * is folded in by the fall-through switch, then mixed one final time. */
#define HASH_JEN(key,keylen,hashv) \
do { \
  unsigned _hj_i,_hj_j,_hj_k; \
  unsigned const char *_hj_key=(unsigned const char*)(key); \
  hashv = 0xfeedbeefu; \
  _hj_i = _hj_j = 0x9e3779b9u; \
  _hj_k = (unsigned)(keylen); \
  while (_hj_k >= 12U) { \
    _hj_i += (_hj_key[0] + ( (unsigned)_hj_key[1] << 8 ) \
        + ( (unsigned)_hj_key[2] << 16 ) \
        + ( (unsigned)_hj_key[3] << 24 ) ); \
    _hj_j += (_hj_key[4] + ( (unsigned)_hj_key[5] << 8 ) \
        + ( (unsigned)_hj_key[6] << 16 ) \
        + ( (unsigned)_hj_key[7] << 24 ) ); \
    hashv += (_hj_key[8] + ( (unsigned)_hj_key[9] << 8 ) \
        + ( (unsigned)_hj_key[10] << 16 ) \
        + ( (unsigned)_hj_key[11] << 24 ) ); \
    \
    HASH_JEN_MIX(_hj_i, _hj_j, hashv); \
    \
    _hj_key += 12; \
    _hj_k -= 12U; \
  } \
  hashv += (unsigned)(keylen); \
  switch ( _hj_k ) { \
    case 11: hashv += ( (unsigned)_hj_key[10] << 24 ); /* FALLTHROUGH */ \
    case 10: hashv += ( (unsigned)_hj_key[9] << 16 ); /* FALLTHROUGH */ \
    case 9: hashv += ( (unsigned)_hj_key[8] << 8 ); /* FALLTHROUGH */ \
    case 8: _hj_j += ( (unsigned)_hj_key[7] << 24 ); /* FALLTHROUGH */ \
    case 7: _hj_j += ( (unsigned)_hj_key[6] << 16 ); /* FALLTHROUGH */ \
    case 6: _hj_j += ( (unsigned)_hj_key[5] << 8 ); /* FALLTHROUGH */ \
    case 5: _hj_j += _hj_key[4]; /* FALLTHROUGH */ \
    case 4: _hj_i += ( (unsigned)_hj_key[3] << 24 ); /* FALLTHROUGH */ \
    case 3: _hj_i += ( (unsigned)_hj_key[2] << 16 ); /* FALLTHROUGH */ \
    case 2: _hj_i += ( (unsigned)_hj_key[1] << 8 ); /* FALLTHROUGH */ \
    case 1: _hj_i += _hj_key[0]; /* FALLTHROUGH */ \
    default: ; \
  } \
  HASH_JEN_MIX(_hj_i, _hj_j, hashv); \
} while (0)

/* The Paul Hsieh hash function */
/* get16bits(d): on the compilers/targets listed below, a direct (possibly
 * unaligned) 16-bit load; everywhere else, a little-endian 16-bit value
 * assembled from the two bytes at d. */
#undef get16bits
#if (defined(__GNUC__) && defined(__i386__)) || defined(__WATCOMC__) \
  || defined(_MSC_VER) || defined (__BORLANDC__) || defined (__TURBOC__)
#define get16bits(d) (*((const uint16_t *) (d)))
#endif

#if !defined (get16bits)
#define get16bits(d) ((((uint32_t)(((const uint8_t *)(d))[1])) << 8) \
                       +(uint32_t)(((const uint8_t *)(d))[0]) )
#endif
/* SuperFastHash: 4 bytes per loop iteration, a 1-3 byte tail switch, then a
 * final avalanche sequence. */
#define HASH_SFH(key,keylen,hashv) \
do { \
  unsigned const char *_sfh_key=(unsigned const char*)(key); \
  uint32_t _sfh_tmp, _sfh_len = (uint32_t)keylen; \
  \
  unsigned _sfh_rem = _sfh_len & 3U; \
  _sfh_len >>= 2; \
  hashv = 0xcafebabeu; \
  \
  /* Main loop */ \
  for (;_sfh_len > 0U; _sfh_len--) { \
    hashv += get16bits (_sfh_key); \
    _sfh_tmp = ((uint32_t)(get16bits (_sfh_key+2)) << 11) ^ hashv; \
    hashv = (hashv << 16) ^ _sfh_tmp; \
    _sfh_key += 2U*sizeof (uint16_t); \
    hashv += hashv >> 11; \
  } \
  \
  /* Handle end cases */ \
  switch (_sfh_rem) { \
    case 3: hashv += get16bits (_sfh_key); \
            hashv ^= hashv << 16; \
            hashv ^= (uint32_t)(_sfh_key[sizeof (uint16_t)]) << 18; \
            hashv += hashv >> 11; \
            break; \
    case 2: hashv += get16bits (_sfh_key); \
            hashv ^= hashv << 11; \
            hashv += hashv >> 17; \
            break; \
    case 1: hashv += *_sfh_key; \
            hashv ^= hashv << 10; \
            hashv += hashv >> 1; \
            break; \
    default: ; \
  } \
  \
  /* Force "avalanching" of final 127 bits */ \
  hashv ^= hashv << 3; \
  hashv += hashv >> 5; \
  hashv ^= hashv << 4; \
  hashv += hashv >> 17; \
  hashv ^= hashv << 25; \
  hashv += hashv >> 6; \
} while (0)

/* iterate over items in a known bucket to find desired item */
/* A candidate matches when hashv and keylen are equal and HASH_KEYCMP
 * returns 0; out is left NULL when the chain holds no match. */
#define HASH_FIND_IN_BKT(tbl,hh,head,keyptr,keylen_in,hashval,out) \
do { \
  if ((head).hh_head != NULL) { \
    DECLTYPE_ASSIGN(out, ELMT_FROM_HH(tbl, (head).hh_head)); \
  } else { \
    (out) = NULL; \
  } \
  while ((out) != NULL) { \
    if ((out)->hh.hashv == (hashval) && (out)->hh.keylen == (keylen_in)) { \
      if (HASH_KEYCMP((out)->hh.key, keyptr, keylen_in) == 0) { \
        break; \
      } \
    } \
    if ((out)->hh.hh_next != NULL) { \
      DECLTYPE_ASSIGN(out, ELMT_FROM_HH(tbl, (out)->hh.hh_next)); \
    } else { \
      (out) = NULL; \
    } \
  } \
} while (0)

/* add an item to a bucket */
/* The handle is pushed onto the front of the bucket chain. Once the chain
 * reaches (expand_mult+1) * HASH_BKT_CAPACITY_THRESH entries (and expansion
 * has not been disabled), the table is doubled. If that expansion fails
 * under non-fatal OOM handling, the handle is unlinked again. */
#define HASH_ADD_TO_BKT(head,hh,addhh,oomed) \
do { \
  UT_hash_bucket *_ha_head = &(head); \
  _ha_head->count++; \
  (addhh)->hh_next = _ha_head->hh_head; \
  (addhh)->hh_prev = NULL; \
  if (_ha_head->hh_head != NULL) { \
    _ha_head->hh_head->hh_prev = (addhh); \
  } \
  _ha_head->hh_head = (addhh); \
  if ((_ha_head->count >= ((_ha_head->expand_mult + 1U) * HASH_BKT_CAPACITY_THRESH)) \
      && !(addhh)->tbl->noexpand) { \
    HASH_EXPAND_BUCKETS(addhh,(addhh)->tbl, oomed); \
    IF_HASH_NONFATAL_OOM( \
      if (oomed) { \
        HASH_DEL_IN_BKT(head,addhh); \
      } \
    ) \
  } \
} while (0)

/* remove an item from a given bucket */
/* Unlinks delhh from the bucket's doubly-linked chain and decrements the
 * bucket count; delhh's own hh_prev/hh_next fields are left untouched. */
#define HASH_DEL_IN_BKT(head,delhh) \
do { \
  UT_hash_bucket *_hd_head = &(head); \
  _hd_head->count--; \
  if (_hd_head->hh_head == (delhh)) { \
    _hd_head->hh_head = (delhh)->hh_next; \
  } \
  if ((delhh)->hh_prev) { \
    (delhh)->hh_prev->hh_next = (delhh)->hh_next; \
  } \
  if ((delhh)->hh_next) { \
    (delhh)->hh_next->hh_prev = (delhh)->hh_prev; \
  } \
} while (0)

/* Bucket expansion has the effect of doubling the number of buckets
 * and redistributing the items into the new buckets. Ideally the
 * items will distribute more or less evenly into the new buckets
 * (the extent to which this is true is a measure of the quality of
 * the hash function as it applies to the key domain).
 *
 * With the items distributed into more buckets, the chain length
 * (item count) in each bucket is reduced. Thus by expanding buckets
 * the hash keeps a bound on the chain length. This bounded chain
 * length is the essence of how a hash provides constant time lookup.
 *
 * The calculation of tbl->ideal_chain_maxlen below deserves some
 * explanation. First, keep in mind that we're calculating the ideal
 * maximum chain length based on the *new* (doubled) bucket count.
 * In fractions this is just n/b (n=number of items,b=new num buckets).
 * Since the ideal chain length is an integer, we want to calculate
 * ceil(n/b). We don't depend on floating point arithmetic in this
 * hash, so to calculate ceil(n/b) with integers we could write
 *
 *      ceil(n/b) = (n/b) + ((n%b)?1:0)
 *
 * and in fact a previous version of this hash did just that.
 * But now we have improved things a bit by recognizing that b is
 * always a power of two. We keep its base 2 log handy (call it lb),
 * so now we can write this with a bit shift and logical AND:
 *
 *      ceil(n/b) = (n>>lb) + ( (n & (b-1)) ? 1:0)
 *
 */
/* Summary: if the doubled bucket array cannot be allocated, the OOM is
 * recorded and the table is left unchanged. Otherwise every handle is
 * rehashed into the new array, the old array is freed, and
 * ineff_expands/noexpand are updated. Expansion is disabled once two
 * consecutive expansions leave more than half the items beyond the ideal
 * chain length. */
#define HASH_EXPAND_BUCKETS(hh,tbl,oomed) \
do { \
  unsigned _he_bkt; \
  unsigned _he_bkt_i; \
  struct UT_hash_handle *_he_thh, *_he_hh_nxt; \
  UT_hash_bucket *_he_new_buckets, *_he_newbkt; \
  _he_new_buckets = (UT_hash_bucket*)uthash_malloc( \
           sizeof(struct UT_hash_bucket) * (tbl)->num_buckets * 2U); \
  if (!_he_new_buckets) { \
    HASH_RECORD_OOM(oomed); \
  } else { \
    uthash_bzero(_he_new_buckets, \
        sizeof(struct UT_hash_bucket) * (tbl)->num_buckets * 2U); \
    (tbl)->ideal_chain_maxlen = \
       ((tbl)->num_items >> ((tbl)->log2_num_buckets+1U)) + \
       ((((tbl)->num_items & (((tbl)->num_buckets*2U)-1U)) != 0U) ? 1U : 0U); \
    (tbl)->nonideal_items = 0; \
    for (_he_bkt_i = 0; _he_bkt_i < (tbl)->num_buckets; _he_bkt_i++) { \
      _he_thh = (tbl)->buckets[ _he_bkt_i ].hh_head; \
      while (_he_thh != NULL) { \
        _he_hh_nxt = _he_thh->hh_next; \
        HASH_TO_BKT(_he_thh->hashv, (tbl)->num_buckets * 2U, _he_bkt); \
        _he_newbkt = &(_he_new_buckets[_he_bkt]); \
        if (++(_he_newbkt->count) > (tbl)->ideal_chain_maxlen) { \
          (tbl)->nonideal_items++; \
          if (_he_newbkt->count > _he_newbkt->expand_mult * (tbl)->ideal_chain_maxlen) { \
            _he_newbkt->expand_mult++; \
          } \
        } \
        _he_thh->hh_prev = NULL; \
        _he_thh->hh_next = _he_newbkt->hh_head; \
        if (_he_newbkt->hh_head != NULL) { \
          _he_newbkt->hh_head->hh_prev = _he_thh; \
        } \
        _he_newbkt->hh_head = _he_thh; \
        _he_thh = _he_hh_nxt; \
      } \
    } \
    uthash_free((tbl)->buckets, (tbl)->num_buckets * sizeof(struct UT_hash_bucket)); \
    (tbl)->num_buckets *= 2U; \
    (tbl)->log2_num_buckets++; \
    (tbl)->buckets = _he_new_buckets; \
    (tbl)->ineff_expands = ((tbl)->nonideal_items > ((tbl)->num_items >> 1)) ? \
        ((tbl)->ineff_expands+1U) : 0U; \
    if ((tbl)->ineff_expands > 1U) { \
      (tbl)->noexpand = 1; \
      uthash_noexpand_fyi(tbl); \
    } \
    uthash_expand_fyi(tbl); \
  } \
} while (0)


/* This is an adaptation of Simon Tatham's O(n log(n)) mergesort */
/* Note that HASH_SORT assumes the hash handle name to be hh.
 * HASH_SRT was added to allow the hash handle name to be passed in. */
/* Sorts the application-order (prev/next) list in place. cmpfcn(a, b)
 * receives two element pointers and returns <= 0 to keep a before b, so
 * equal elements keep their relative order. Bucket chains are not touched.
 * On completion head points at the first element and tbl->tail at the
 * last. */
#define HASH_SORT(head,cmpfcn) HASH_SRT(hh,head,cmpfcn)
#define HASH_SRT(hh,head,cmpfcn) \
do { \
  unsigned _hs_i; \
  unsigned _hs_looping,_hs_nmerges,_hs_insize,_hs_psize,_hs_qsize; \
  struct UT_hash_handle *_hs_p, *_hs_q, *_hs_e, *_hs_list, *_hs_tail; \
  if (head != NULL) { \
    _hs_insize = 1; \
    _hs_looping = 1; \
    _hs_list = &((head)->hh); \
    while (_hs_looping != 0U) { \
      _hs_p = _hs_list; \
      _hs_list = NULL; \
      _hs_tail = NULL; \
      _hs_nmerges = 0; \
      while (_hs_p != NULL) { \
        _hs_nmerges++; \
        _hs_q = _hs_p; \
        _hs_psize = 0; \
        for (_hs_i = 0; _hs_i < _hs_insize; ++_hs_i) { \
          _hs_psize++; \
          _hs_q = ((_hs_q->next != NULL) ? \
            HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \
          if (_hs_q == NULL) { \
            break; \
          } \
        } \
        _hs_qsize = _hs_insize; \
        while ((_hs_psize != 0U) || ((_hs_qsize != 0U) && (_hs_q != NULL))) { \
          if (_hs_psize == 0U) { \
            _hs_e = _hs_q; \
            _hs_q = ((_hs_q->next != NULL) ? \
              HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \
            _hs_qsize--; \
          } else if ((_hs_qsize == 0U) || (_hs_q == NULL)) { \
            _hs_e = _hs_p; \
            if (_hs_p != NULL) { \
              _hs_p = ((_hs_p->next != NULL) ? \
                HH_FROM_ELMT((head)->hh.tbl, _hs_p->next) : NULL); \
            } \
            _hs_psize--; \
          } else if ((cmpfcn( \
                DECLTYPE(head)(ELMT_FROM_HH((head)->hh.tbl, _hs_p)), \
                DECLTYPE(head)(ELMT_FROM_HH((head)->hh.tbl, _hs_q)) \
                )) <= 0) { \
            _hs_e = _hs_p; \
            if (_hs_p != NULL) { \
              _hs_p = ((_hs_p->next != NULL) ? \
                HH_FROM_ELMT((head)->hh.tbl, _hs_p->next) : NULL); \
            } \
            _hs_psize--; \
          } else { \
            _hs_e = _hs_q; \
            _hs_q = ((_hs_q->next != NULL) ? \
              HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \
            _hs_qsize--; \
          } \
          if ( _hs_tail != NULL ) { \
            _hs_tail->next = ((_hs_e != NULL) ? \
              ELMT_FROM_HH((head)->hh.tbl, _hs_e) : NULL); \
          } else { \
            _hs_list = _hs_e; \
          } \
          if (_hs_e != NULL) { \
            _hs_e->prev = ((_hs_tail != NULL) ? \
              ELMT_FROM_HH((head)->hh.tbl, _hs_tail) : NULL); \
          } \
          _hs_tail = _hs_e; \
        } \
        _hs_p = _hs_q; \
      } \
      if (_hs_tail != NULL) { \
        _hs_tail->next = NULL; \
      } \
      if (_hs_nmerges <= 1U) { \
        _hs_looping = 0; \
        (head)->hh.tbl->tail = _hs_tail; \
        DECLTYPE_ASSIGN(head, ELMT_FROM_HH((head)->hh.tbl, _hs_list)); \
      } \
      _hs_insize *= 2U; \
    } \
    HASH_FSCK(hh, head, "HASH_SRT"); \
  } \
} while (0)

/* This function selects items from one hash into another hash.
 * The end result is that the selected items have dual presence
 * in both hashes. There is no copy of the items made; rather
 * they are added into the new hash through a secondary
 * hash handle that must be present in the structure.
*/ 982 | #define HASH_SELECT(hh_dst, dst, hh_src, src, cond) \ 983 | do { \ 984 | unsigned _src_bkt, _dst_bkt; \ 985 | void *_last_elt = NULL, *_elt; \ 986 | UT_hash_handle *_src_hh, *_dst_hh, *_last_elt_hh=NULL; \ 987 | ptrdiff_t _dst_hho = ((char*)(&(dst)->hh_dst) - (char*)(dst)); \ 988 | if ((src) != NULL) { \ 989 | for (_src_bkt=0; _src_bkt < (src)->hh_src.tbl->num_buckets; _src_bkt++) { \ 990 | for (_src_hh = (src)->hh_src.tbl->buckets[_src_bkt].hh_head; \ 991 | _src_hh != NULL; \ 992 | _src_hh = _src_hh->hh_next) { \ 993 | _elt = ELMT_FROM_HH((src)->hh_src.tbl, _src_hh); \ 994 | if (cond(_elt)) { \ 995 | IF_HASH_NONFATAL_OOM( int _hs_oomed = 0; ) \ 996 | _dst_hh = (UT_hash_handle*)(void*)(((char*)_elt) + _dst_hho); \ 997 | _dst_hh->key = _src_hh->key; \ 998 | _dst_hh->keylen = _src_hh->keylen; \ 999 | _dst_hh->hashv = _src_hh->hashv; \ 1000 | _dst_hh->prev = _last_elt; \ 1001 | _dst_hh->next = NULL; \ 1002 | if (_last_elt_hh != NULL) { \ 1003 | _last_elt_hh->next = _elt; \ 1004 | } \ 1005 | if ((dst) == NULL) { \ 1006 | DECLTYPE_ASSIGN(dst, _elt); \ 1007 | HASH_MAKE_TABLE(hh_dst, dst, _hs_oomed); \ 1008 | IF_HASH_NONFATAL_OOM( \ 1009 | if (_hs_oomed) { \ 1010 | uthash_nonfatal_oom(_elt); \ 1011 | (dst) = NULL; \ 1012 | continue; \ 1013 | } \ 1014 | ) \ 1015 | } else { \ 1016 | _dst_hh->tbl = (dst)->hh_dst.tbl; \ 1017 | } \ 1018 | HASH_TO_BKT(_dst_hh->hashv, _dst_hh->tbl->num_buckets, _dst_bkt); \ 1019 | HASH_ADD_TO_BKT(_dst_hh->tbl->buckets[_dst_bkt], hh_dst, _dst_hh, _hs_oomed); \ 1020 | (dst)->hh_dst.tbl->num_items++; \ 1021 | IF_HASH_NONFATAL_OOM( \ 1022 | if (_hs_oomed) { \ 1023 | HASH_ROLLBACK_BKT(hh_dst, dst, _dst_hh); \ 1024 | HASH_DELETE_HH(hh_dst, dst, _dst_hh); \ 1025 | _dst_hh->tbl = NULL; \ 1026 | uthash_nonfatal_oom(_elt); \ 1027 | continue; \ 1028 | } \ 1029 | ) \ 1030 | HASH_BLOOM_ADD(_dst_hh->tbl, _dst_hh->hashv); \ 1031 | _last_elt = _elt; \ 1032 | _last_elt_hh = _dst_hh; \ 1033 | } \ 1034 | } \ 1035 | } \ 1036 | } \ 1037 | HASH_FSCK(hh_dst, 
dst, "HASH_SELECT"); \ 1038 | } while (0) 1039 | 1040 | #define HASH_CLEAR(hh,head) \ 1041 | do { \ 1042 | if ((head) != NULL) { \ 1043 | HASH_BLOOM_FREE((head)->hh.tbl); \ 1044 | uthash_free((head)->hh.tbl->buckets, \ 1045 | (head)->hh.tbl->num_buckets*sizeof(struct UT_hash_bucket)); \ 1046 | uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \ 1047 | (head) = NULL; \ 1048 | } \ 1049 | } while (0) 1050 | 1051 | #define HASH_OVERHEAD(hh,head) \ 1052 | (((head) != NULL) ? ( \ 1053 | (size_t)(((head)->hh.tbl->num_items * sizeof(UT_hash_handle)) + \ 1054 | ((head)->hh.tbl->num_buckets * sizeof(UT_hash_bucket)) + \ 1055 | sizeof(UT_hash_table) + \ 1056 | (HASH_BLOOM_BYTELEN))) : 0U) 1057 | 1058 | #ifdef NO_DECLTYPE 1059 | #define HASH_ITER(hh,head,el,tmp) \ 1060 | for(((el)=(head)), ((*(char**)(&(tmp)))=(char*)((head!=NULL)?(head)->hh.next:NULL)); \ 1061 | (el) != NULL; ((el)=(tmp)), ((*(char**)(&(tmp)))=(char*)((tmp!=NULL)?(tmp)->hh.next:NULL))) 1062 | #else 1063 | #define HASH_ITER(hh,head,el,tmp) \ 1064 | for(((el)=(head)), ((tmp)=DECLTYPE(el)((head!=NULL)?(head)->hh.next:NULL)); \ 1065 | (el) != NULL; ((el)=(tmp)), ((tmp)=DECLTYPE(el)((tmp!=NULL)?(tmp)->hh.next:NULL))) 1066 | #endif 1067 | 1068 | /* obtain a count of items in the hash */ 1069 | #define HASH_COUNT(head) HASH_CNT(hh,head) 1070 | #define HASH_CNT(hh,head) ((head != NULL)?((head)->hh.tbl->num_items):0U) 1071 | 1072 | typedef struct UT_hash_bucket { 1073 | struct UT_hash_handle *hh_head; 1074 | unsigned count; 1075 | 1076 | /* expand_mult is normally set to 0. In this situation, the max chain length 1077 | * threshold is enforced at its default value, HASH_BKT_CAPACITY_THRESH. (If 1078 | * the bucket's chain exceeds this length, bucket expansion is triggered). 1079 | * However, setting expand_mult to a non-zero value delays bucket expansion 1080 | * (that would be triggered by additions to this particular bucket) 1081 | * until its chain length reaches a *multiple* of HASH_BKT_CAPACITY_THRESH. 
1082 | * (The multiplier is simply expand_mult+1). The whole idea of this 1083 | * multiplier is to reduce bucket expansions, since they are expensive, in 1084 | * situations where we know that a particular bucket tends to be overused. 1085 | * It is better to let its chain length grow to a longer yet-still-bounded 1086 | * value, than to do an O(n) bucket expansion too often. 1087 | */ 1088 | unsigned expand_mult; 1089 | 1090 | } UT_hash_bucket; 1091 | 1092 | /* random signature used only to find hash tables in external analysis */ 1093 | #define HASH_SIGNATURE 0xa0111fe1u 1094 | #define HASH_BLOOM_SIGNATURE 0xb12220f2u 1095 | 1096 | typedef struct UT_hash_table { 1097 | UT_hash_bucket *buckets; 1098 | unsigned num_buckets, log2_num_buckets; 1099 | unsigned num_items; 1100 | struct UT_hash_handle *tail; /* tail hh in app order, for fast append */ 1101 | ptrdiff_t hho; /* hash handle offset (byte pos of hash handle in element */ 1102 | 1103 | /* in an ideal situation (all buckets used equally), no bucket would have 1104 | * more than ceil(#items/#buckets) items. that's the ideal chain length. */ 1105 | unsigned ideal_chain_maxlen; 1106 | 1107 | /* nonideal_items is the number of items in the hash whose chain position 1108 | * exceeds the ideal chain maxlen. these items pay the penalty for an uneven 1109 | * hash distribution; reaching them in a chain traversal takes >ideal steps */ 1110 | unsigned nonideal_items; 1111 | 1112 | /* ineffective expands occur when a bucket doubling was performed, but 1113 | * afterward, more than half the items in the hash had nonideal chain 1114 | * positions. If this happens on two consecutive expansions we inhibit any 1115 | * further expansion, as it's not helping; this happens when the hash 1116 | * function isn't a good fit for the key domain. When expansion is inhibited 1117 | * the hash will still work, albeit no longer in constant time. 
*/ 1118 | unsigned ineff_expands, noexpand; 1119 | 1120 | uint32_t signature; /* used only to find hash tables in external analysis */ 1121 | #ifdef HASH_BLOOM 1122 | uint32_t bloom_sig; /* used only to test bloom exists in external analysis */ 1123 | uint8_t *bloom_bv; 1124 | uint8_t bloom_nbits; 1125 | #endif 1126 | 1127 | } UT_hash_table; 1128 | 1129 | typedef struct UT_hash_handle { 1130 | struct UT_hash_table *tbl; 1131 | void *prev; /* prev element in app order */ 1132 | void *next; /* next element in app order */ 1133 | struct UT_hash_handle *hh_prev; /* previous hh in bucket order */ 1134 | struct UT_hash_handle *hh_next; /* next hh in bucket order */ 1135 | const void *key; /* ptr to enclosing struct's key */ 1136 | unsigned keylen; /* enclosing struct's key len */ 1137 | unsigned hashv; /* result of hash-fcn(key) */ 1138 | } UT_hash_handle; 1139 | 1140 | #endif /* UTHASH_H */ 1141 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from operator import itemgetter 5 | import torch.nn.utils 6 | from surel_gacc import sjoin 7 | from sklearn.metrics import roc_auc_score 8 | from torch.nn import BCEWithLogitsLoss 9 | 10 | from utils import * 11 | 12 | 13 | def train(model, opti, data, dT): 14 | model.train() 15 | total_loss = 0 16 | labels, preds = [], [] 17 | for wl, wr, label, x in data: 18 | labels.append(label) 19 | Tf = torch.stack([dT[wl], dT[wr]]) 20 | opti.zero_grad() 21 | pred = model(Tf, [wl, wr]) 22 | preds.append(pred.detach().sigmoid()) 23 | target = label.to(pred.device) 24 | loss = BCEWithLogitsLoss()(pred, target) 25 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 26 | loss.backward() 27 | opti.step() 28 | total_loss += loss.item() * len(label) 29 | predictions = torch.cat(preds).cpu() 30 | labels = torch.cat(labels) 31 | return total_loss / 
len(labels), roc_auc_score(labels, predictions) 32 | 33 | 34 | def eval_model(model, x_dict, x_set, args, evaluator, device, mode='test', return_predictions=False): 35 | model.eval() 36 | preds = [] 37 | with torch.no_grad(): 38 | x_embed, target = x_set['X'], x_set[mode]['E'] 39 | with tqdm(total=len(target)) as pbar: 40 | for batch in gen_batch(target, args.batch_num, keep=True): 41 | Bs = torch.unique(batch).numpy() 42 | S, K, F = zip(*itemgetter(*Bs)(x_dict)) 43 | S = torch.from_numpy(np.asarray(S)).long() 44 | F = np.concatenate(F) 45 | F = np.concatenate([[[0] * F.shape[-1]], F]) 46 | mF = torch.from_numpy(F).to(device) 47 | uvw, uvx = sjoin(S, K, batch, return_idx=True) 48 | uvw = uvw.reshape(2, -1, 2) 49 | x = torch.from_numpy(uvw) 50 | gT = normalization(mF, args) 51 | gT = torch.stack([gT[uvw[0]], gT[uvw[1]]]) 52 | pred = model(gT, x) 53 | preds.append(pred.sigmoid()) 54 | pbar.update(len(pred)) 55 | predictions = torch.cat(preds, dim=0) 56 | 57 | if not return_predictions: 58 | labels = torch.zeros(len(predictions)) 59 | result_dict = {'metric': args.metric, 'mode': mode} 60 | if args.metric == 'mrr': 61 | num_pos = x_set[mode]['num_pos'] 62 | labels[:num_pos] = 1 63 | pred_pos, pred_neg = predictions[:num_pos], predictions[num_pos:] 64 | result_dict['mrr_list'] = \ 65 | evaluator.eval({"y_pred_pos": pred_pos.view(-1), "y_pred_neg": pred_neg.view(num_pos, -1)})['mrr_list'] 66 | elif 'Hits' in args.metric: 67 | num_neg = x_set[mode]['num_neg'] 68 | labels[num_neg:] = 1 69 | pred_neg, pred_pos = predictions[:num_neg], predictions[num_neg:] 70 | result_dict['hits'] = evaluate_hits(pred_pos.view(-1), pred_neg.view(-1), evaluator) 71 | result_dict['num_pos'] = len(pred_pos) 72 | else: 73 | raise NotImplementedError 74 | 75 | result_dict['auc'] = roc_auc_score(labels, predictions.cpu()) 76 | 77 | return result_dict 78 | else: 79 | return predictions 80 | 81 | 82 | def eval_model_horder(model, x_dict, x_set, args, evaluator, device, mode='test', 
return_predictions=False): 83 | model.eval() 84 | preds = [] 85 | with torch.no_grad(): 86 | x_embed, target = x_set['X'], x_set[mode]['E'] 87 | with tqdm(total=len(target)) as pbar: 88 | for batch in gen_batch(target, args.batch_num, keep=True): 89 | Bs = torch.unique(batch).numpy() 90 | S, K, F = zip(*itemgetter(*Bs)(x_dict)) 91 | S = torch.from_numpy(np.asarray(S)).long() 92 | F = np.concatenate(F) 93 | F = np.concatenate([[[0] * F.shape[-1]], F]) 94 | mF = torch.from_numpy(F).to(device) 95 | uw = sjoin(S, K, batch[:, [0, 2]], return_idx=False) 96 | vw = sjoin(S, K, batch[:, [1, 2]], return_idx=False) 97 | uvw = np.concatenate([uw, vw], axis=1).reshape(2, -1, 2) 98 | x = torch.from_numpy(uvw) 99 | gT = normalization(mF, args) 100 | gT = torch.stack([gT[uvw[0]], gT[uvw[1]]]) 101 | pred = model(gT, x) 102 | preds.append(pred.sigmoid()) 103 | pbar.update(len(pred)) 104 | predictions = torch.cat(preds, dim=0) 105 | 106 | if not return_predictions: 107 | labels = torch.zeros(len(predictions)) 108 | result_dict = {'metric': args.metric, 'mode': mode} 109 | if args.metric == 'mrr': 110 | num_pos = x_set[mode]['num_pos'] 111 | labels[:num_pos] = 1 112 | pred_pos, pred_neg = predictions[:num_pos], predictions[num_pos:] 113 | result_dict['mrr_list'] = \ 114 | evaluator.eval({"y_pred_pos": pred_pos.view(-1), "y_pred_neg": pred_neg.view(num_pos, -1)})['mrr_list'] 115 | elif 'Hits' in args.metric: 116 | num_neg = x_set[mode]['num_neg'] 117 | labels[num_neg:] = 1 118 | pred_neg, pred_pos = predictions[:num_neg], predictions[num_neg:] 119 | result_dict['hits'] = evaluate_hits(pred_pos.view(-1), pred_neg.view(-1), evaluator) 120 | result_dict['num_pos'] = len(pred_pos) 121 | else: 122 | raise NotImplementedError 123 | 124 | result_dict['auc'] = roc_auc_score(labels, predictions.cpu()) 125 | 126 | return result_dict 127 | else: 128 | return predictions 129 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | import torch 8 | from surel_gacc import run_walk 9 | from tqdm import tqdm 10 | 11 | def set_random_seed(args): 12 | seed = args.seed 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | np.random.seed(seed) 18 | random.seed(seed) 19 | os.environ['PYTHONHASHSEED'] = str(seed) 20 | 21 | def gen_batch(iterable, n=1, keep=False): 22 | length = len(iterable) 23 | if keep: 24 | for ndx in range(0, length, n): 25 | yield iterable[ndx:min(ndx + n, length)] 26 | else: 27 | for ndx in range(0, length - n, n): 28 | yield iterable[ndx:min(ndx + n, length)] 29 | 30 | 31 | def np_sampling(rw_dict, ptr, neighs, bsize, target, num_walks=100, num_steps=3): 32 | with tqdm(total=len(target)) as pbar: 33 | for batch in gen_batch(target, bsize, True): 34 | walk_set, freqs = run_walk(ptr, neighs, batch, num_walks=num_walks, num_steps=num_steps, replacement=True) 35 | node_id, node_freq = freqs[:, 0], freqs[:, 1] 36 | rw_dict.update(dict(zip(batch, zip(walk_set, node_id, node_freq)))) 37 | pbar.update(len(batch)) 38 | return rw_dict 39 | 40 | 41 | def sample(high: int, size: int, device=None): 42 | size = min(high, size) 43 | return torch.tensor(random.sample(range(high), size), device=device) 44 | 45 | 46 | def coarse(Tx, K): 47 | # repeat base as the length of unique nodes appeared in its set of walks 48 | xid = torch.from_numpy(np.arange(len(Tx)).repeat(list(map(len, K)))) 49 | yid = np.concatenate(K) 50 | Ty = sorted(set(yid)) 51 | num_nodes, base = len(Tx), len(Ty) 52 | 53 | # remapping root nodes 54 | xm = -torch.ones(max(Tx) + 1, dtype=torch.long) 55 | xm[Tx] = torch.arange(num_nodes) 56 | # remapping walk nodes 57 | ym = -torch.ones(max(yid) + 1, dtype=torch.long) 58 | ym[Ty] 
= torch.arange(base) 59 | 60 | mB = torch.zeros(num_nodes * base, dtype=torch.long) 61 | mB[xid * base + ym[yid]] = torch.arange(len(yid)) + 1 62 | 63 | return xm, ym, base, mB 64 | 65 | 66 | def gen_sample(S, Tx, K, pos_edges, full_edges, x_embed, args, gtype='Homogeneous'): 67 | unit_size = args.num_step * args.num_walk 68 | num_nodes = len(Tx) 69 | # for hetero graph 70 | if gtype != 'Homogeneous': 71 | Tx = sorted(Tx) 72 | Ws = torch.from_numpy(S).long() 73 | xm, ym, base, mB = coarse(Tx, K) 74 | xr = torch.tensor(Tx, dtype=torch.long) 75 | 76 | mA = torch.zeros([num_nodes, num_nodes], dtype=torch.long) 77 | row, col = xm[torch.cat([full_edges, full_edges[[1, 0]]], dim=-1)] 78 | mA[row, col] = 1 79 | if gtype == 'Homogeneous': 80 | mA = mA @ mA + mA 81 | # else: 82 | # mA = torch.zeros([num_nodes, num_nodes], dtype=torch.long) 83 | # row, col = xm[torch.cat([full_edges, full_edges[[1, 0]]], dim=-1)] 84 | # pivot = xm[full_edges[0].min()] 85 | # mA[row, col] = 1 86 | # mA[:pivot, :pivot] = 1 87 | # mA[pivot:, pivot:] = 1 88 | perm = torch.arange(num_nodes * num_nodes)[~ (mA.view(-1) > 0)] 89 | neg_pair = torch.vstack([torch.div(perm, num_nodes, rounding_mode='floor'), perm % num_nodes]).t() 90 | perms = sample(len(neg_pair), args.k * num_nodes) 91 | neg_edges = neg_pair[perms].t() 92 | 93 | edge_pairs = torch.cat([xm[pos_edges], neg_edges], dim=1) 94 | labels = torch.zeros(edge_pairs.shape[1]) 95 | labels[:pos_edges.shape[1]] = 1 96 | idx = np.random.permutation(len(labels)) 97 | 98 | batch_size = args.batch_size * (args.k + 1) 99 | 100 | for bidx in gen_batch(idx, batch_size, keep=True): 101 | batch = edge_pairs[:, bidx] 102 | uidx, vidx = batch 103 | ubase = (uidx * base).repeat_interleave(unit_size) 104 | vbase = (vidx * base).repeat_interleave(unit_size) 105 | Wu, Wv = Ws[uidx].view(-1), Ws[vidx].view(-1) 106 | u_offset, v_offset = ym[Wu], ym[Wv] 107 | wu = mB[torch.stack([ubase + u_offset, vbase + u_offset], dim=-1)] 108 | wv = mB[torch.stack([ubase + 
v_offset, vbase + v_offset], dim=-1)] 109 | 110 | if x_embed is not None: 111 | if args.use_degree or args.use_htype: 112 | yield wu, wv, labels[bidx], (x_embed[torch.stack([Wu, Wv])], x_embed[xr[batch]]) 113 | else: 114 | yield wu, wv, labels[bidx], x_embed 115 | else: 116 | yield wu, wv, labels[bidx], None 117 | 118 | 119 | def gen_tuple(W, Tx, S, pos_tuple, args): 120 | unit_size = args.num_step * args.num_walk 121 | num_pos = len(pos_tuple) 122 | Ws = torch.from_numpy(W).long() 123 | xm, ym, base, mB = coarse(Tx, S) 124 | 125 | # do trivial random sampling 126 | dst_neg = torch.tensor([np.random.choice(Tx, args.k, replace=False) for _ in range(num_pos)]) 127 | src_neg = pos_tuple[:, :2].repeat(1, args.k).view(-1, 2) 128 | neg_tuple = torch.cat([src_neg, dst_neg.view(-1, 1)], dim=1) 129 | neg_label = [i + num_pos for i, t in enumerate(neg_tuple) if torch.all(pos_tuple == t, dim=1).sum() > 0] 130 | tuples = xm[torch.cat([pos_tuple, neg_tuple]).t()] 131 | labels = torch.zeros(tuples.shape[1]) 132 | labels[:num_pos] = 1 133 | labels[neg_label] = 1 134 | idx = np.random.permutation(len(labels)) 135 | 136 | batch_size = args.batch_size * (args.k + 1) 137 | 138 | for bidx in gen_batch(idx, batch_size, keep=True): 139 | batch = tuples[:, bidx] 140 | uidx, vidx, widx = batch 141 | ubase = (uidx * base).repeat_interleave(unit_size) 142 | vbase = (vidx * base).repeat_interleave(unit_size) 143 | wbase = (widx * base).repeat_interleave(unit_size) 144 | Xu, Xv, Xw = Ws[uidx].view(-1), Ws[vidx].view(-1), Ws[widx].view(-1) 145 | u_offset, v_offset, w_offset = ym[Xu], ym[Xv], ym[Xw] 146 | wu = mB[torch.stack([ubase + u_offset, wbase + u_offset], dim=-1)] 147 | uw = mB[torch.stack([ubase + w_offset, wbase + w_offset], dim=-1)] 148 | wv = mB[torch.stack([vbase + v_offset, wbase + v_offset], dim=-1)] 149 | vw = mB[torch.stack([vbase + w_offset, wbase + w_offset], dim=-1)] 150 | 151 | yield torch.cat([wu, wv]), torch.cat([uw, vw]), labels[bidx], None 152 | 153 | 154 | def 
normalization(T, args): 155 | if args.use_weight: 156 | norm = torch.tensor([args.num_walk] * args.num_step + [args.w_max], device=T.device) 157 | else: 158 | if args.norm == 'all': 159 | norm = args.num_walk 160 | elif args.norm == 'root': 161 | norm = torch.tensor([args.num_walk] + [1] * args.num_step, device=T.device) 162 | else: 163 | raise NotImplementedError 164 | return T / norm 165 | 166 | 167 | # from https://github.com/facebookresearch/SEAL_OGB 168 | def get_pos_neg_edges(split, split_edge, ratio=1.0, keep_neg=False): 169 | if 'source_node' in split_edge['train']: 170 | source = split_edge[split]['source_node'] 171 | target = split_edge[split]['target_node'] 172 | target_neg = split_edge[split]['target_node_neg'] 173 | # subsample 174 | np.random.seed(123) 175 | num_source = source.size(0) 176 | perm = np.random.permutation(num_source) 177 | perm = perm[:int(ratio * num_source)] 178 | source, target, target_neg = source[perm], target[perm], target_neg[perm, :] 179 | pos_edge = torch.stack([source, target]) 180 | neg_per_target = target_neg.size(1) 181 | neg_edge = torch.stack([source.repeat_interleave(neg_per_target), target_neg.view(-1)]) 182 | elif 'edge' in split_edge['train']: 183 | pos_edge = split_edge[split]['edge'].t() 184 | neg_edge = split_edge[split]['edge_neg'].t() 185 | # subsample for pos_edge 186 | if ratio < 1: 187 | np.random.seed(123) 188 | num_pos = pos_edge.size(1) 189 | perm = np.random.permutation(num_pos) 190 | perm = perm[:int(ratio * num_pos)] 191 | pos_edge = pos_edge[:, perm] 192 | # subsample for neg_edge 193 | if not keep_neg: 194 | np.random.seed(123) 195 | num_neg = neg_edge.size(1) 196 | if num_neg // num_pos == 1000: 197 | neg_edge = neg_edge.t().view(num_pos, -1, 2)[perm].reshape(-1, 2).t() 198 | else: 199 | perm = np.random.permutation(num_neg) 200 | perm = perm[:int(ratio * num_neg)] 201 | neg_edge = neg_edge[:, perm] 202 | elif 'hedge' in split_edge['train']: 203 | pos_edge = split_edge[split]['hedge'].t() 204 | 
neg_edge = split_edge[split]['hedge_neg'] 205 | if ratio < 1: 206 | np.random.seed(123) 207 | num_pos = pos_edge.size(1) 208 | perm = np.random.permutation(num_pos) 209 | perm = perm[:int(ratio * num_pos)] 210 | pos_edge = pos_edge[:, perm] 211 | neg_edge = neg_edge.view(num_pos, -1, 3)[perm].reshape(-1, 3) 212 | pos_edge = pos_edge.t() 213 | else: 214 | raise NotImplementedError 215 | return pos_edge, neg_edge 216 | 217 | 218 | def evaluate_hits(pos_pred, neg_pred, evaluator): 219 | results = {} 220 | for K in [10, 20, 50, 100]: 221 | evaluator.K = K 222 | res_hits = evaluator.eval({ 223 | 'y_pred_pos': pos_pred, 224 | 'y_pred_neg': neg_pred, 225 | })[f'hits@{K}'] 226 | 227 | results[f'Hits@{K}'] = res_hits 228 | return results 229 | 230 | 231 | def save_checkpoint(state, filename='checkpoint'): 232 | print("=> Saving checkpoint") 233 | torch.save(state, f'{filename}.pth.tar') 234 | 235 | 236 | def load_checkpoint(model, optimizer, filename): 237 | checkpoint = torch.load(f'{filename}.pth.tar') 238 | print(f"<= Loading checkpoint from epoch {checkpoint['epoch']}") 239 | model.load_state_dict(checkpoint['state_dict']) 240 | optimizer.load_state_dict(checkpoint['optimizer']) 241 | --------------------------------------------------------------------------------