├── mxgraph
│   ├── __init__.py
│   ├── helpers
│   │   ├── __init__.py
│   │   ├── metric_logger.py
│   │   ├── ordered_easydict.py
│   │   └── email_sender.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── common.py
│   │   ├── graph_rnn.py
│   │   └── stack_layers.py
│   ├── cfg_helper.py
│   ├── utils.py
│   ├── config.py
│   └── sampler.py
├── .gitignore
├── experiments
│   └── static_graph
│       ├── cfg_ppi_sup
│       │   ├── run.sh
│       │   ├── 1206_ppi_20_sup_d0_pool_avg_128_512_1_128_512_1_all_2_d0.10.1.yml
│       │   ├── 1206_ppi_20_sup_d0_pool_max_128_512_1_128_512_1_all_2_d0.10.1.yml
│       │   ├── 100_ppi_20_sup_d0_mugga_g0_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml
│       │   ├── 100_ppi_20_sup_d0_mugga_g1_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml
│       │   ├── 1206_ppi_20_sup_d0_mugga_g1_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml
│       │   ├── 200_ppi_20_sup_d0_mugga_g0_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml
│       │   ├── 200_ppi_20_sup_d0_mugga_g1_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml
│       │   ├── 1206_ppi_20_sup_d0_mugga_g0_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml
│       │   ├── 1206_ppi_20_sup_d0_multi_weighted_div1_tanh_128_64_24_8_128_64_24_8_all_2_d0.10.1.yml
│       │   └── 1206_ppi_20_sup_d0_multi_weighted_div1_sigmoid_128_64_24_8_128_64_24_8_all_2_d0.10.1.yml
│       ├── README.md
│       ├── cfg_reddit_sup
│       │   ├── 215_reddit_sup_d0_mugga_g0_128_32_512_1_64_1_128_32_256_2_64_1_fixed_25_10_d0.10.1.yml
│       │   ├── 215_reddit_sup_d0_mugga_g1_128_32_512_1_64_1_128_32_128_4_64_1_fixed_25_10_d0.10.1.yml
│       │   ├── 215_reddit_sup_d0_multi_weighted_div1_tanh_128_256_32_4_128_256_32_4_fixed_25_10_d0.10.1.yml
│       │   ├── 215_reddit_sup_d0_multi_weighted_div1_sigmoid_128_256_32_4_128_256_32_4_fixed_25_10_d0.10.1.yml
│       │   ├── 215_reddit_sup_d0_pool_avg_128_1024_1_128_1024_1_fixed_25_10_d0.10.1.yml
│       │   └── 215_reddit_sup_d0_pool_max_128_1024_1_128_1024_1_fixed_25_10_d0.10.1.yml
│       └── sup_train_sample.py
├── GraphSampler
│   ├── test_sampler.py
│   ├── CMakeLists.txt
│   ├── README.md
│   ├── install.py
│   ├── cmake
│   │   └── Modules
│   │       └── FindNumpy.cmake
│   ├── graph_sampler.h
│   └── py_ext.cpp
├── setup.py
├── seg_ops_cuda
│   ├── CMakeLists.txt
│   └── README.md
├── README.md
└── download_data.py
/mxgraph/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mxgraph/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /mxgraph/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .aggregators import * 3 | from .stack_layers import * 4 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python ../sup_train_sample.py --cfg 1206_ppi_20_sup_d0_mugga_g1_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml --ctx gpu0 4 | -------------------------------------------------------------------------------- /experiments/static_graph/README.md: -------------------------------------------------------------------------------- 1 | # Train 2 | 3 | - supervised training 4 | ```bash 5 | python sup_train_sample.py --cfg yourcfgfilename --ctx gpu0 6 |
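# For example, with one of the sample configs shipped in this repo
# (assuming the cfg path is resolved relative to the current working directory):
# python sup_train_sample.py --cfg cfg_ppi_sup/1206_ppi_20_sup_d0_pool_avg_128_512_1_128_512_1_all_2_d0.10.1.yml --ctx gpu0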
``` 7 | 8 | Sample cfg files can be found in the cfg_ppi_sup and cfg_reddit_sup directories. -------------------------------------------------------------------------------- /mxgraph/helpers/metric_logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from collections import OrderedDict 4 | 5 | class MetricLogger(object): 6 | def __init__(self, attr_names, parse_formats, save_path): 7 | self._attr_format_dict = OrderedDict(zip(attr_names, parse_formats)) 8 | self._file = open(save_path, 'w') 9 | self._csv = csv.writer(self._file) 10 | self._csv.writerow(attr_names) 11 | self._file.flush() 12 | 13 | def log(self, **kwargs): 14 | self._csv.writerow([parse_format % kwargs[attr_name] 15 | for attr_name, parse_format in self._attr_format_dict.items()]) 16 | self._file.flush() 17 | -------------------------------------------------------------------------------- /GraphSampler/test_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mxgraph.iterators import cfg_data_loader 3 | from mxgraph.sampler import FixedNeighborSampler 4 | from mxgraph.graph import set_seed 5 | set_seed(100) 6 | 7 | G, features, _, _, _ = cfg_data_loader() 8 | sampler = FixedNeighborSampler(layer_num=2, neighbor_num=[4, 2]) 9 | indices_in_merged_l, end_points_l, indptr_l, node_ids_l = sampler.sample_by_indices(G, np.arange(10)) 10 | print("node_ids_l", node_ids_l[0].shape, node_ids_l[1].shape, node_ids_l[2].shape) 11 | print(node_ids_l) 12 | print("indptr_l", indptr_l[0].shape, indptr_l[1].shape) 13 | print(indptr_l) 14 | print("end_points_l", end_points_l[0].shape, end_points_l[1].shape) 15 | print(end_points_l) 16 | -------------------------------------------------------------------------------- /GraphSampler/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | project(GraphSampler) 3 | if(UNIX) 4 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -std=c++11 -O3 -ffast-math") 5 | endif() 6 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/") 7 | set(CMAKE_CXX_STANDARD 11) 8 | find_package(PythonInterp 3 REQUIRED) 9 | find_package(PythonLibs 3 REQUIRED) 10 | find_package(Numpy REQUIRED) 11 | include_directories(SYSTEM ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) 12 | find_package(OpenMP) 13 | if (OPENMP_FOUND) 14 | set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 15 | set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 16 | endif() 17 | add_library(graph_sampler SHARED py_ext.cpp graph_sampler.cpp) 18 | target_link_libraries(graph_sampler ${PYTHON_LIBRARIES}) 19 | add_executable(graph_sampler_test graph_sampler.cpp) 20 | -------------------------------------------------------------------------------- /GraphSampler/README.md: -------------------------------------------------------------------------------- 1 | # C++ Extensions for Graph Sampler in Python 2 | 3 | The C++ implementation of the graph sampler used by mxgraph. 4 | 5 | The input graph is assumed to have the following format: 6 | - node_types: ... 7 | - end_points: ... 8 | - ind_ptr: ... 9 | - node_ids: ...
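To make the layout above concrete: the field descriptions are elided ("...") in the original, so the sketch below — `ind_ptr[i]:ind_ptr[i+1]` slicing the neighbors of node `i` out of `end_points`, as in the usual CSR convention — is an assumption, not the documented semantics.

```python
import numpy as np

# Hypothetical 4-node graph with directed edges 0->1, 0->2, 1->2, 2->3, 3->0.
node_ids = np.array([0, 1, 2, 3])       # IDs of the nodes in this (sub)graph
ind_ptr = np.array([0, 2, 3, 4, 5])     # CSR row pointer, length = num_nodes + 1
end_points = np.array([1, 2, 2, 3, 0])  # concatenated neighbor lists

for i in range(len(node_ids)):
    print(node_ids[i], "->", end_points[ind_ptr[i]:ind_ptr[i + 1]])
```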
10 | 11 | The sampler supports two modes: 12 | - Sample a subgraph from the given graph 13 | - Sample a subgraph from the given graph w.r.t. some given nodes 14 | 15 | # Install 16 | For Windows users: 17 | 18 | ```bash 19 | mkdir build 20 | cd build 21 | cmake -G "Visual Studio 14 2015 Win64" -DCMAKE_BUILD_TYPE=Release -DCMAKE_CONFIGURATION_TYPES="Release" .. 22 | ``` 23 | Open GraphSampler.sln and build it with VS 2015, then 24 | ```bash 25 | cd .. 26 | python install.py 27 | ``` 28 | 29 | For Unix users (including macOS and Linux): 30 | ```bash 31 | mkdir build 32 | cd build 33 | cmake .. 34 | make 35 | cd .. 36 | python install.py 37 | ``` 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import setuptools 3 | 4 | setuptools.setup( 5 | name='GaAN', 6 | version="0.1.dev0", 7 | author="Jiani Zhang, Xingjian Shi", 8 | author_email="jnzhang@cuhk.edu.hk, xshiab@cse.ust.hk", 9 | packages=setuptools.find_packages(), 10 | description='GluonGraph', 11 | long_description=open(os.path.join(os.path.dirname( 12 | os.path.abspath(__file__)), 'README.md')).read(), 13 | license='MIT', 14 | url='https://github.com/sxjscience/MXGraph', 15 | install_requires=['numpy', 'scipy', 'matplotlib', 'six', 'pyyaml', 'networkx', 'scikit-learn', 'pandas'], 16 | classifiers=['Development Status :: 2 - Pre-Alpha', 17 | 'Intended Audience :: Science/Research', 18 | 'License :: OSI Approved :: MIT License', 19 | 'Natural Language :: English', 20 | 'Operating System :: OS Independent', 21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence'], 22 | ) -------------------------------------------------------------------------------- /seg_ops_cuda/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | project(seg_ops) 3 | 4 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O3 -ffast-math -Wall") 5 | find_package(OpenMP REQUIRED) 6 | if (OPENMP_FOUND) 7 | set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 8 | set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 9 | endif() 10 | 11 | LIST(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_30,code=sm_30) 12 | LIST(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_35,code=sm_35) 13 | LIST(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_50,code=sm_50) 14 | LIST(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_52,code=sm_52) 15 | LIST(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_60,code=sm_60) 16 | LIST(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_61,code=sm_61) 17 | LIST(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_62,code=sm_62) 18 | LIST(APPEND CUDA_NVCC_FLAGS -gencode arch=compute_70,code=sm_70) 19 | 20 | find_package(CUDA REQUIRED) 21 | include_directories(cub) 22 | # Specify binary name and source file to build it from 23 | 24 | cuda_add_executable( 25 | seg_ops_test 26 | seg_ops.cu) 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GaAN 2 | 3 | The MXNet implementation of GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs 4 | (UAI 2018). 5 | 6 | 7 | We only support Python 3! 8 | 9 | ## Installation 10 | 11 | Compile the MXNet operators by following the guide in [seg_ops_cuda](seg_ops_cuda). Install the graph sampler by following the guide in [GraphSampler](GraphSampler).
13 | 14 | ```bash 15 | python setup.py develop 16 | ``` 17 | 18 | ## Download datasets 19 | You can download the datasets via the *download_data.py script. The usage is like 20 | ```bash 21 | python download_data.py --dataset ppi 22 | ``` 23 | The --dataset hyperparameter can be 'cora', 'ppi', and 'reddit'. 24 | 25 | ## Run experiments 26 | The script is experiments/static_graph/sup_train_sample.py. 27 | 28 | ## Citation 29 | ``` 30 | @inproceedings{zhang18, 31 | author = {Jiani Zhang and Xingjian Shi and Junyuan Xie and Hao Ma and Irwin King and Dit{-}Yan Yeung}, 32 | title = {GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs}, 33 | booktitle = {Proceedings of the Thirty-Fourth Conference on Uncertainty in Artificial Intelligence}, 34 | pages = {339--349}, 35 | year = {2018} 36 | } 37 | ``` 38 | -------------------------------------------------------------------------------- /GraphSampler/install.py: -------------------------------------------------------------------------------- 1 | def load_c_plugin(): 2 | import os 3 | import shutil 4 | from sys import platform 5 | 6 | _BASE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 7 | 8 | # Copy the graph_sampler python lib to the destination 9 | if platform.lower().startswith("win"): 10 | _VALID_DLL_PATHS = [os.path.join(_BASE_PATH, 'GraphSampler', 'build', 'Release', 'graph_sampler.dll')] 11 | _TARGET_PATH = os.path.join(_BASE_PATH, 'mxgraph', '_graph_sampler.pyd') 12 | else: 13 | _VALID_DLL_PATHS = [os.path.join(_BASE_PATH, 'GraphSampler', 'build', 'libgraph_sampler.dylib'), 14 | os.path.join(_BASE_PATH, 'GraphSampler', 'build', 'libgraph_sampler.so')] 15 | _TARGET_PATH = os.path.join(_BASE_PATH, 'mxgraph', '_graph_sampler.so') 16 | 17 | found = False 18 | for p in _VALID_DLL_PATHS: 19 | if os.path.exists(p): 20 | found = True 21 | print("Found python extension for graph sampling, path=%s. Copy to %s" % (p, _TARGET_PATH)) 22 | shutil.copy(p, _TARGET_PATH) 23 | break 24 | if not found: 25 | raise RuntimeError( 26 | "Graph sampling extensions not found! 
Please check these paths: %s" % (str(_VALID_DLL_PATHS))) 27 | load_c_plugin() 28 | -------------------------------------------------------------------------------- /mxgraph/helpers/ordered_easydict.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | class OrderedEasyDict(OrderedDict): 5 | """Using OrderedDict for the `easydict` package 6 | See Also https://pypi.python.org/pypi/easydict/ 7 | """ 8 | def __init__(self, d=None, **kwargs): 9 | super(OrderedEasyDict, self).__init__() 10 | if d is None: 11 | d = OrderedDict() 12 | if kwargs: 13 | d.update(**kwargs) 14 | for k, v in d.items(): 15 | setattr(self, k, v) 16 | # Class attributes 17 | for k in self.__class__.__dict__.keys(): 18 | if not (k.startswith('__') and k.endswith('__')): 19 | setattr(self, k, getattr(self, k)) 20 | 21 | def __setattr__(self, name, value): 22 | # special handling of self.__root and self.__map 23 | if name.startswith('_') and (name.endswith('__root') or name.endswith('__map')): 24 | super(OrderedEasyDict, self).__setattr__(name, value) 25 | else: 26 | if isinstance(value, (list, tuple)): 27 | value = [self.__class__(x) 28 | if isinstance(x, dict) else x for x in value] 29 | else: 30 | value = self.__class__(value) if isinstance(value, dict) else value 31 | super(OrderedEasyDict, self).__setattr__(name, value) 32 | super(OrderedEasyDict, self).__setitem__(name, value) 33 | 34 | __setitem__ = __setattr__ 35 | 36 | if __name__ == "__main__": 37 | import doctest 38 | doctest.testmod() -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/1206_ppi_20_sup_d0_pool_avg_128_512_1_128_512_1_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: avg 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | MUGGA: 18 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 19 | USE_EDGE: false 20 | ATTEND_W_DROPOUT: 0.0 21 | CONTEXT: 22 | USE_SUM_POOL: false 23 | USE_MAX_POOL: true 24 | USE_AVG_POOL: true 25 | USE_GATE: true 26 | USE_SHARPNESS: true 27 | STATIC_GRAPH: 28 | MODEL: 29 | TYP: supervised 30 | FEATURE_NORMALIZE: false 31 | FIRST_EMBED_UNITS: 64 32 | AGGREGATOR_ARGS_LIST: 33 | - - GraphPoolAggregator 34 | - [128, 512, 1] 35 | - - GraphPoolAggregator 36 | - [128, 512, 1] 37 | DROPOUT_RATE_LIST: [0.1, 0.1] 38 | DENSE_CONNECT: false 39 | L2_NORMALIZATION: false 40 | EVERY_LAYER_L2_NORMALIZATION: false 41 | EMBED_DIM: 64 42 | NEG_WEIGHT: 1.0 43 | TRAIN_NEG_SAMPLE_SCALE: 0 44 | TRAIN_NEG_SAMPLE_REPLACE: false 45 | VALID_NEG_SAMPLE_SCALE: 0 46 | TEST_NEG_SAMPLE_SCALE: 0 47 | TRAIN: 48 | BATCH_SIZE: 512 49 | GRAPH_SAMPLER_ARGS: [all, 2] 50 | VALID_ITER: 100 51 | TEST_ITER: 100 52 | MAX_ITER: 100000 53 | OPTIMIZER: adam 54 | LR: 0.01 55 | MIN_LR: 0.001 56 | DECAY_PATIENCE: 15 57 | EARLY_STOPPING_PATIENCE: 30 58 | LR_DECAY_FACTOR: 0.5 59 | GRAD_CLIP: -1.0 60 | WD: 0.0 61 | TEST: 62 | BATCH_SIZE: 512 63 | SAMPLE_NUM: 1 64 | -------------------------------------------------------------------------------- 
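As a side note on how these YAML files can be consumed: a minimal loading sketch built on the `OrderedEasyDict` shown above. The repository's actual loading logic lives in `mxgraph/config.py`/`mxgraph/cfg_helper.py`, which are not included in this listing, so this is only an illustration.

```python
import yaml
from mxgraph.helpers.ordered_easydict import OrderedEasyDict

with open('1206_ppi_20_sup_d0_pool_avg_128_512_1_128_512_1_all_2_d0.10.1.yml') as f:
    cfg = OrderedEasyDict(yaml.safe_load(f))

# Nested mappings become attribute-accessible:
print(cfg.STATIC_GRAPH.TRAIN.BATCH_SIZE)   # 512
print(cfg.AGGREGATOR.GRAPHPOOL.POOL_TYPE)  # avg
```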
/experiments/static_graph/cfg_ppi_sup/1206_ppi_20_sup_d0_pool_max_128_512_1_128_512_1_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | MUGGA: 18 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 19 | USE_EDGE: false 20 | ATTEND_W_DROPOUT: 0.0 21 | CONTEXT: 22 | USE_SUM_POOL: false 23 | USE_MAX_POOL: true 24 | USE_AVG_POOL: true 25 | USE_GATE: true 26 | USE_SHARPNESS: true 27 | STATIC_GRAPH: 28 | MODEL: 29 | TYP: supervised 30 | FEATURE_NORMALIZE: false 31 | FIRST_EMBED_UNITS: 64 32 | AGGREGATOR_ARGS_LIST: 33 | - - GraphPoolAggregator 34 | - [128, 512, 1] 35 | - - GraphPoolAggregator 36 | - [128, 512, 1] 37 | DROPOUT_RATE_LIST: [0.1, 0.1] 38 | DENSE_CONNECT: false 39 | L2_NORMALIZATION: false 40 | EVERY_LAYER_L2_NORMALIZATION: false 41 | EMBED_DIM: 64 42 | NEG_WEIGHT: 1.0 43 | TRAIN_NEG_SAMPLE_SCALE: 0 44 | TRAIN_NEG_SAMPLE_REPLACE: false 45 | VALID_NEG_SAMPLE_SCALE: 0 46 | TEST_NEG_SAMPLE_SCALE: 0 47 | TRAIN: 48 | BATCH_SIZE: 512 49 | GRAPH_SAMPLER_ARGS: [all, 2] 50 | VALID_ITER: 100 51 | TEST_ITER: 100 52 | MAX_ITER: 100000 53 | OPTIMIZER: adam 54 | LR: 0.01 55 | MIN_LR: 0.001 56 | DECAY_PATIENCE: 15 57 | EARLY_STOPPING_PATIENCE: 30 58 | LR_DECAY_FACTOR: 0.5 59 | GRAD_CLIP: -1.0 60 | WD: 0.0 61 | TEST: 62 | BATCH_SIZE: 512 63 | SAMPLE_NUM: 1 64 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/100_ppi_20_sup_d0_mugga_g0_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | MUGGA: 18 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 19 | USE_EDGE: false 20 | ATTEND_W_DROPOUT: 0.0 21 | CONTEXT: 22 | USE_SUM_POOL: false 23 | USE_MAX_POOL: true 24 | USE_AVG_POOL: true 25 | USE_GATE: false 26 | USE_SHARPNESS: false 27 | STATIC_GRAPH: 28 | MODEL: 29 | TYP: supervised 30 | FEATURE_NORMALIZE: false 31 | FIRST_EMBED_UNITS: 64 32 | AGGREGATOR_ARGS_LIST: 33 | - - MuGGA 34 | - [128, 24, 32, 8, 64, 1] 35 | - - MuGGA 36 | - [128, 24, 32, 8, 64, 1] 37 | DROPOUT_RATE_LIST: [0.1, 0.1] 38 | DENSE_CONNECT: false 39 | L2_NORMALIZATION: false 40 | EVERY_LAYER_L2_NORMALIZATION: false 41 | EMBED_DIM: 64 42 | NEG_WEIGHT: 1.0 43 | TRAIN_NEG_SAMPLE_SCALE: 0 44 | TRAIN_NEG_SAMPLE_REPLACE: false 45 | VALID_NEG_SAMPLE_SCALE: 0 46 | TEST_NEG_SAMPLE_SCALE: 0 47 | TRAIN: 48 | BATCH_SIZE: 512 49 | GRAPH_SAMPLER_ARGS: [all, 2] 50 | VALID_ITER: 100 51 | TEST_ITER: 100 52 | MAX_ITER: 100000 53 | OPTIMIZER: adam 54 | LR: 0.01 55 | MIN_LR: 0.001 56 | DECAY_PATIENCE: 15 57 | EARLY_STOPPING_PATIENCE: 30 58 | LR_DECAY_FACTOR: 0.5 59 | GRAD_CLIP: -1.0 60 | WD: 0.0 61 | TEST: 
62 | BATCH_SIZE: 512 63 | SAMPLE_NUM: 1 64 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/100_ppi_20_sup_d0_mugga_g1_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | MUGGA: 18 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 19 | USE_EDGE: false 20 | ATTEND_W_DROPOUT: 0.0 21 | CONTEXT: 22 | USE_SUM_POOL: false 23 | USE_MAX_POOL: true 24 | USE_AVG_POOL: true 25 | USE_GATE: true 26 | USE_SHARPNESS: false 27 | STATIC_GRAPH: 28 | MODEL: 29 | TYP: supervised 30 | FEATURE_NORMALIZE: false 31 | FIRST_EMBED_UNITS: 64 32 | AGGREGATOR_ARGS_LIST: 33 | - - MuGGA 34 | - [128, 24, 32, 8, 64, 1] 35 | - - MuGGA 36 | - [128, 24, 32, 8, 64, 1] 37 | DROPOUT_RATE_LIST: [0.1, 0.1] 38 | DENSE_CONNECT: false 39 | L2_NORMALIZATION: false 40 | EVERY_LAYER_L2_NORMALIZATION: false 41 | EMBED_DIM: 64 42 | NEG_WEIGHT: 1.0 43 | TRAIN_NEG_SAMPLE_SCALE: 0 44 | TRAIN_NEG_SAMPLE_REPLACE: false 45 | VALID_NEG_SAMPLE_SCALE: 0 46 | TEST_NEG_SAMPLE_SCALE: 0 47 | TRAIN: 48 | BATCH_SIZE: 512 49 | GRAPH_SAMPLER_ARGS: [all, 2] 50 | VALID_ITER: 100 51 | TEST_ITER: 100 52 | MAX_ITER: 100000 53 | OPTIMIZER: adam 54 | LR: 0.01 55 | MIN_LR: 0.001 56 | DECAY_PATIENCE: 15 57 | EARLY_STOPPING_PATIENCE: 30 58 | LR_DECAY_FACTOR: 0.5 59 | GRAD_CLIP: -1.0 60 | WD: 0.0 61 | TEST: 62 | BATCH_SIZE: 512 63 | SAMPLE_NUM: 1 64 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/1206_ppi_20_sup_d0_mugga_g1_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | MUGGA: 18 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 19 | USE_EDGE: false 20 | ATTEND_W_DROPOUT: 0.0 21 | CONTEXT: 22 | USE_SUM_POOL: false 23 | USE_MAX_POOL: true 24 | USE_AVG_POOL: true 25 | USE_GATE: true 26 | USE_SHARPNESS: false 27 | STATIC_GRAPH: 28 | MODEL: 29 | TYP: supervised 30 | FEATURE_NORMALIZE: false 31 | FIRST_EMBED_UNITS: 64 32 | AGGREGATOR_ARGS_LIST: 33 | - - MuGGA 34 | - [128, 24, 32, 8, 64, 1] 35 | - - MuGGA 36 | - [128, 24, 32, 8, 64, 1] 37 | DROPOUT_RATE_LIST: [0.1, 0.1] 38 | DENSE_CONNECT: false 39 | L2_NORMALIZATION: false 40 | EVERY_LAYER_L2_NORMALIZATION: false 41 | EMBED_DIM: 64 42 | NEG_WEIGHT: 1.0 43 | TRAIN_NEG_SAMPLE_SCALE: 0 44 | TRAIN_NEG_SAMPLE_REPLACE: false 45 | VALID_NEG_SAMPLE_SCALE: 0 46 | TEST_NEG_SAMPLE_SCALE: 0 47 | TRAIN: 48 | BATCH_SIZE: 512 49 | GRAPH_SAMPLER_ARGS: [all, 2] 50 | VALID_ITER: 100 51 | TEST_ITER: 100 52 | MAX_ITER: 100000 53 | OPTIMIZER: adam 54 | LR: 0.01 55 | MIN_LR: 
0.001 56 | DECAY_PATIENCE: 15 57 | EARLY_STOPPING_PATIENCE: 30 58 | LR_DECAY_FACTOR: 0.5 59 | GRAD_CLIP: -1.0 60 | WD: 0.0 61 | TEST: 62 | BATCH_SIZE: 512 63 | SAMPLE_NUM: 1 64 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/200_ppi_20_sup_d0_mugga_g0_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | MUGGA: 18 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 19 | USE_EDGE: false 20 | ATTEND_W_DROPOUT: 0.0 21 | CONTEXT: 22 | USE_SUM_POOL: false 23 | USE_MAX_POOL: true 24 | USE_AVG_POOL: true 25 | USE_GATE: false 26 | USE_SHARPNESS: false 27 | STATIC_GRAPH: 28 | MODEL: 29 | TYP: supervised 30 | FEATURE_NORMALIZE: false 31 | FIRST_EMBED_UNITS: 64 32 | AGGREGATOR_ARGS_LIST: 33 | - - MuGGA 34 | - [128, 24, 32, 8, 64, 1] 35 | - - MuGGA 36 | - [128, 24, 32, 8, 64, 1] 37 | DROPOUT_RATE_LIST: [0.1, 0.1] 38 | DENSE_CONNECT: false 39 | L2_NORMALIZATION: false 40 | EVERY_LAYER_L2_NORMALIZATION: false 41 | EMBED_DIM: 64 42 | NEG_WEIGHT: 1.0 43 | TRAIN_NEG_SAMPLE_SCALE: 0 44 | TRAIN_NEG_SAMPLE_REPLACE: false 45 | VALID_NEG_SAMPLE_SCALE: 0 46 | TEST_NEG_SAMPLE_SCALE: 0 47 | TRAIN: 48 | BATCH_SIZE: 512 49 | GRAPH_SAMPLER_ARGS: [all, 2] 50 | VALID_ITER: 100 51 | TEST_ITER: 100 52 | MAX_ITER: 100000 53 | OPTIMIZER: adam 54 | LR: 0.01 55 | MIN_LR: 0.001 56 | DECAY_PATIENCE: 15 57 | EARLY_STOPPING_PATIENCE: 30 58 | LR_DECAY_FACTOR: 0.5 59 | GRAD_CLIP: -1.0 60 | WD: 0.0 61 | TEST: 62 | BATCH_SIZE: 512 63 | SAMPLE_NUM: 1 64 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/200_ppi_20_sup_d0_mugga_g1_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | MUGGA: 18 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 19 | USE_EDGE: false 20 | ATTEND_W_DROPOUT: 0.0 21 | CONTEXT: 22 | USE_SUM_POOL: false 23 | USE_MAX_POOL: true 24 | USE_AVG_POOL: true 25 | USE_GATE: true 26 | USE_SHARPNESS: false 27 | STATIC_GRAPH: 28 | MODEL: 29 | TYP: supervised 30 | FEATURE_NORMALIZE: false 31 | FIRST_EMBED_UNITS: 64 32 | AGGREGATOR_ARGS_LIST: 33 | - - MuGGA 34 | - [128, 24, 32, 8, 64, 1] 35 | - - MuGGA 36 | - [128, 24, 32, 8, 64, 1] 37 | DROPOUT_RATE_LIST: [0.1, 0.1] 38 | DENSE_CONNECT: false 39 | L2_NORMALIZATION: false 40 | EVERY_LAYER_L2_NORMALIZATION: false 41 | EMBED_DIM: 64 42 | NEG_WEIGHT: 1.0 43 | TRAIN_NEG_SAMPLE_SCALE: 0 44 | TRAIN_NEG_SAMPLE_REPLACE: false 45 | VALID_NEG_SAMPLE_SCALE: 0 46 | TEST_NEG_SAMPLE_SCALE: 0 47 | TRAIN: 48 | BATCH_SIZE: 512 49 | 
GRAPH_SAMPLER_ARGS: [all, 2] 50 | VALID_ITER: 100 51 | TEST_ITER: 100 52 | MAX_ITER: 100000 53 | OPTIMIZER: adam 54 | LR: 0.01 55 | MIN_LR: 0.001 56 | DECAY_PATIENCE: 15 57 | EARLY_STOPPING_PATIENCE: 30 58 | LR_DECAY_FACTOR: 0.5 59 | GRAD_CLIP: -1.0 60 | WD: 0.0 61 | TEST: 62 | BATCH_SIZE: 512 63 | SAMPLE_NUM: 1 64 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/1206_ppi_20_sup_d0_mugga_g0_128_24_32_8_64_1_128_24_32_8_64_1_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | MUGGA: 18 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 19 | USE_EDGE: false 20 | ATTEND_W_DROPOUT: 0.0 21 | CONTEXT: 22 | USE_SUM_POOL: false 23 | USE_MAX_POOL: true 24 | USE_AVG_POOL: true 25 | USE_GATE: false 26 | USE_SHARPNESS: false 27 | STATIC_GRAPH: 28 | MODEL: 29 | TYP: supervised 30 | FEATURE_NORMALIZE: false 31 | FIRST_EMBED_UNITS: 64 32 | AGGREGATOR_ARGS_LIST: 33 | - - MuGGA 34 | - [128, 24, 32, 8, 64, 1] 35 | - - MuGGA 36 | - [128, 24, 32, 8, 64, 1] 37 | DROPOUT_RATE_LIST: [0.1, 0.1] 38 | DENSE_CONNECT: false 39 | L2_NORMALIZATION: false 40 | EVERY_LAYER_L2_NORMALIZATION: false 41 | EMBED_DIM: 64 42 | NEG_WEIGHT: 1.0 43 | TRAIN_NEG_SAMPLE_SCALE: 0 44 | TRAIN_NEG_SAMPLE_REPLACE: false 45 | VALID_NEG_SAMPLE_SCALE: 0 46 | TEST_NEG_SAMPLE_SCALE: 0 47 | TRAIN: 48 | BATCH_SIZE: 512 49 | GRAPH_SAMPLER_ARGS: [all, 2] 50 | VALID_ITER: 100 51 | TEST_ITER: 100 52 | MAX_ITER: 100000 53 | OPTIMIZER: adam 54 | LR: 0.01 55 | MIN_LR: 0.001 56 | DECAY_PATIENCE: 15 57 | EARLY_STOPPING_PATIENCE: 30 58 | LR_DECAY_FACTOR: 0.5 59 | GRAD_CLIP: -1.0 60 | WD: 0.0 61 | TEST: 62 | BATCH_SIZE: 512 63 | SAMPLE_NUM: 1 64 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_reddit_sup/215_reddit_sup_d0_mugga_g0_128_32_512_1_64_1_128_32_256_2_64_1_fixed_25_10_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 215 2 | NPY_SEED: 215 3 | DATA_NAME: reddit 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | WEIGHT_ACT: sigmoid 16 | MUGGA: 17 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 18 | USE_EDGE: false 19 | ATTEND_W_DROPOUT: 0.5 20 | CONTEXT: 21 | USE_SUM_POOL: false 22 | USE_MAX_POOL: true 23 | USE_AVG_POOL: true 24 | USE_GATE: false 25 | USE_SHARPNESS: false 26 | STATIC_GRAPH: 27 | MODEL: 28 | TYP: supervised 29 | FEATURE_NORMALIZE: true 30 | FIRST_EMBED_UNITS: 64 31 | AGGREGATOR_ARGS_LIST: 32 | - - MuGGA 33 | - [128, 32, 512, 1, 64, 1] 34 | - - MuGGA 35 | - [128, 32, 256, 2, 64, 1] 36 | DROPOUT_RATE_LIST: [0.1, 0.1] 37 | DENSE_CONNECT: false 38 | L2_NORMALIZATION: false 39 | EVERY_LAYER_L2_NORMALIZATION: false 40 | EMBED_DIM: 64 41 | NEG_WEIGHT: 1.0 42 | TRAIN_NEG_SAMPLE_SCALE: 0 43 | 
TRAIN_NEG_SAMPLE_REPLACE: false 44 | VALID_NEG_SAMPLE_SCALE: 0 45 | TEST_NEG_SAMPLE_SCALE: 0 46 | TRAIN: 47 | BATCH_SIZE: 512 48 | GRAPH_SAMPLER_ARGS: 49 | - fixed 50 | - [25, 10] 51 | VALID_ITER: 100 52 | TEST_ITER: 100 53 | MAX_ITER: 100000 54 | OPTIMIZER: adam 55 | LR: 0.001 56 | MIN_LR: 0.0001 57 | DECAY_PATIENCE: 4 58 | EARLY_STOPPING_PATIENCE: 10 59 | LR_DECAY_FACTOR: 0.5 60 | GRAD_CLIP: 1.0 61 | WD: 0.0 62 | TEST: 63 | BATCH_SIZE: 512 64 | SAMPLE_NUM: 1 65 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_reddit_sup/215_reddit_sup_d0_mugga_g1_128_32_512_1_64_1_128_32_128_4_64_1_fixed_25_10_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 215 2 | NPY_SEED: 215 3 | DATA_NAME: reddit 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | WEIGHT_ACT: sigmoid 16 | MUGGA: 17 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 18 | USE_EDGE: false 19 | ATTEND_W_DROPOUT: 0.5 20 | CONTEXT: 21 | USE_SUM_POOL: false 22 | USE_MAX_POOL: true 23 | USE_AVG_POOL: true 24 | USE_GATE: true 25 | USE_SHARPNESS: false 26 | STATIC_GRAPH: 27 | MODEL: 28 | TYP: supervised 29 | FEATURE_NORMALIZE: true 30 | FIRST_EMBED_UNITS: 64 31 | AGGREGATOR_ARGS_LIST: 32 | - - MuGGA 33 | - [128, 32, 512, 1, 64, 1] 34 | - - MuGGA 35 | - [128, 32, 128, 4, 64, 1] 36 | DROPOUT_RATE_LIST: [0.1, 0.1] 37 | DENSE_CONNECT: false 38 | L2_NORMALIZATION: false 39 | EVERY_LAYER_L2_NORMALIZATION: false 40 | EMBED_DIM: 64 41 | NEG_WEIGHT: 1.0 42 | TRAIN_NEG_SAMPLE_SCALE: 0 43 | TRAIN_NEG_SAMPLE_REPLACE: false 44 | VALID_NEG_SAMPLE_SCALE: 0 45 | TEST_NEG_SAMPLE_SCALE: 0 46 | TRAIN: 47 | BATCH_SIZE: 512 48 | GRAPH_SAMPLER_ARGS: 49 | - fixed 50 | - [25, 10] 51 | VALID_ITER: 100 52 | TEST_ITER: 100 53 | MAX_ITER: 100000 54 | OPTIMIZER: adam 55 | LR: 0.001 56 | MIN_LR: 0.0001 57 | DECAY_PATIENCE: 4 58 | EARLY_STOPPING_PATIENCE: 10 59 | LR_DECAY_FACTOR: 0.5 60 | GRAD_CLIP: 1.0 61 | WD: 0.0 62 | TEST: 63 | BATCH_SIZE: 512 64 | SAMPLE_NUM: 1 65 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/1206_ppi_20_sup_d0_multi_weighted_div1_tanh_128_64_24_8_128_64_24_8_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: tanh 17 | GRAPH_MULTI_WEIGHTED_SUM: 18 | ARGS: [out_units, mid_units, attend_units, K] 19 | ATTEND_W_DROPOUT: 0.0 20 | DIVIDE_SIZE: true 21 | WEIGHT_ACT: sigmoid 22 | MUGGA: 23 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 24 | USE_EDGE: false 25 | RESCALE_INNERPRODUCT: true 26 | ATTEND_W_DROPOUT: 0.0 27 | CONTEXT: 28 | USE_SUM_POOL: false 29 | USE_MAX_POOL: true 30 | USE_AVG_POOL: true 31 | USE_GATE: true 32 | USE_SHARPNESS: false 33 | STATIC_GRAPH: 34 | MODEL: 35 | TYP: supervised 36 | FEATURE_NORMALIZE: 
false 37 | FIRST_EMBED_UNITS: 64 38 | AGGREGATOR_ARGS_LIST: 39 | - - GraphMultiWeightedSumAggregator 40 | - [128, 64, 24, 8] 41 | - - GraphMultiWeightedSumAggregator 42 | - [128, 64, 24, 8] 43 | DROPOUT_RATE_LIST: [0.1, 0.1] 44 | DENSE_CONNECT: false 45 | L2_NORMALIZATION: false 46 | EVERY_LAYER_L2_NORMALIZATION: false 47 | EMBED_DIM: 64 48 | NEG_WEIGHT: 1.0 49 | TRAIN_NEG_SAMPLE_SCALE: 0 50 | TRAIN_NEG_SAMPLE_REPLACE: false 51 | VALID_NEG_SAMPLE_SCALE: 0 52 | TEST_NEG_SAMPLE_SCALE: 0 53 | TRAIN: 54 | BATCH_SIZE: 512 55 | GRAPH_SAMPLER_ARGS: [all, 2] 56 | VALID_ITER: 100 57 | TEST_ITER: 100 58 | MAX_ITER: 100000 59 | OPTIMIZER: adam 60 | LR: 0.001 61 | MIN_LR: 0.001 62 | DECAY_PATIENCE: 15 63 | EARLY_STOPPING_PATIENCE: 30 64 | LR_DECAY_FACTOR: 0.5 65 | GRAD_CLIP: -1.0 66 | WD: 0.0 67 | TEST: 68 | BATCH_SIZE: 512 69 | SAMPLE_NUM: 1 70 | -------------------------------------------------------------------------------- /mxgraph/helpers/email_sender.py: -------------------------------------------------------------------------------- 1 | from smtplib import SMTP_SSL as SMTP 2 | import logging 3 | import logging.handlers 4 | import os 5 | import subprocess 6 | import argparse 7 | from email.mime.text import MIMEText 8 | 9 | SRC_ADDRESS = "email_sender123@163.com" 10 | SMTP_SERVER = "smtp.163.com" 11 | PASSWORD = "a12b34c56" 12 | 13 | __PATH__ = os.path.abspath(__file__) 14 | __DIR_NAME = os.path.dirname(__PATH__) 15 | 16 | __LOG_FILE = open(os.path.join(__DIR_NAME, 'email_sender.log'), 'a') 17 | 18 | def _send_msg(title, text, dst_address): 19 | title = 'By Xingjian Email Sender: ' + title 20 | text = 'By Xingjian Email Sender: \n' + text 21 | msg = MIMEText(text, 'plain') 22 | msg['Subject'] = title 23 | msg['From'] = SRC_ADDRESS 24 | msg['To'] = str(dst_address) 25 | try: 26 | conn = SMTP(SMTP_SERVER) 27 | conn.set_debuglevel(True) 28 | conn.login(SRC_ADDRESS, PASSWORD) 29 | try: 30 | conn.sendmail(SRC_ADDRESS, dst_address, msg.as_string()) 31 | finally: 32 | conn.close() 33 | 34 | except Exception as exc: 35 | logging.error("ERROR!!!") 36 | logging.critical(exc) 37 | raise RuntimeError 38 | 39 | 40 | def send_msg(title, text, dst_address): 41 | subprocess.Popen(['python3', __PATH__, 42 | '--title', str(title), 43 | '--text', str(text), 44 | '--dst_address', str(dst_address)], 45 | stdout=__LOG_FILE, 46 | stderr=__LOG_FILE) 47 | 48 | 49 | def parse_args(): 50 | parser = argparse.ArgumentParser(description='Send email given the title and content.') 51 | parser.add_argument('--title', dest='title', required=True, type=str) 52 | parser.add_argument('--text', dest='text', required=True, type=str) 53 | parser.add_argument('--dst_address', dest='dst_address', required=True, type=str) 54 | args = parser.parse_args() 55 | return args 56 | 57 | 58 | if __name__ == "__main__": 59 | args = parse_args() 60 | _send_msg(title=args.title, text=args.text, dst_address=args.dst_address) 61 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_ppi_sup/1206_ppi_20_sup_d0_multi_weighted_div1_sigmoid_128_64_24_8_128_64_24_8_all_2_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 1206 2 | NPY_SEED: 1206 3 | DATA_NAME: ppi 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | 
ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | GRAPH_MULTI_WEIGHTED_SUM: 18 | ARGS: [out_units, mid_units, attend_units, K] 19 | ATTEND_W_DROPOUT: 0.0 20 | DIVIDE_SIZE: true 21 | WEIGHT_ACT: sigmoid 22 | MUGGA: 23 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 24 | USE_EDGE: false 25 | RESCALE_INNERPRODUCT: true 26 | ATTEND_W_DROPOUT: 0.0 27 | CONTEXT: 28 | USE_SUM_POOL: false 29 | USE_MAX_POOL: true 30 | USE_AVG_POOL: true 31 | USE_GATE: true 32 | USE_SHARPNESS: false 33 | STATIC_GRAPH: 34 | MODEL: 35 | TYP: supervised 36 | FEATURE_NORMALIZE: false 37 | FIRST_EMBED_UNITS: 64 38 | AGGREGATOR_ARGS_LIST: 39 | - - GraphMultiWeightedSumAggregator 40 | - [128, 64, 24, 8] 41 | - - GraphMultiWeightedSumAggregator 42 | - [128, 64, 24, 8] 43 | DROPOUT_RATE_LIST: [0.1, 0.1] 44 | DENSE_CONNECT: false 45 | L2_NORMALIZATION: false 46 | EVERY_LAYER_L2_NORMALIZATION: false 47 | EMBED_DIM: 64 48 | NEG_WEIGHT: 1.0 49 | TRAIN_NEG_SAMPLE_SCALE: 0 50 | TRAIN_NEG_SAMPLE_REPLACE: false 51 | VALID_NEG_SAMPLE_SCALE: 0 52 | TEST_NEG_SAMPLE_SCALE: 0 53 | TRAIN: 54 | BATCH_SIZE: 512 55 | GRAPH_SAMPLER_ARGS: [all, 2] 56 | VALID_ITER: 100 57 | TEST_ITER: 100 58 | MAX_ITER: 100000 59 | OPTIMIZER: adam 60 | LR: 0.001 61 | MIN_LR: 0.001 62 | DECAY_PATIENCE: 15 63 | EARLY_STOPPING_PATIENCE: 30 64 | LR_DECAY_FACTOR: 0.5 65 | GRAD_CLIP: -1.0 66 | WD: 0.0 67 | TEST: 68 | BATCH_SIZE: 512 69 | SAMPLE_NUM: 1 70 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_reddit_sup/215_reddit_sup_d0_multi_weighted_div1_tanh_128_256_32_4_128_256_32_4_fixed_25_10_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 215 2 | NPY_SEED: 215 3 | DATA_NAME: reddit 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: tanh 17 | GRAPH_MULTI_WEIGHTED_SUM: 18 | ARGS: [out_units, mid_units, attend_units, K] 19 | ATTEND_W_DROPOUT: 0.0 20 | DIVIDE_SIZE: true 21 | WEIGHT_ACT: sigmoid 22 | MUGGA: 23 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 24 | USE_EDGE: false 25 | RESCALE_INNERPRODUCT: true 26 | ATTEND_W_DROPOUT: 0.5 27 | CONTEXT: 28 | USE_SUM_POOL: false 29 | USE_MAX_POOL: true 30 | USE_AVG_POOL: true 31 | USE_GATE: true 32 | USE_SHARPNESS: false 33 | STATIC_GRAPH: 34 | MODEL: 35 | TYP: supervised 36 | FEATURE_NORMALIZE: true 37 | FIRST_EMBED_UNITS: 64 38 | AGGREGATOR_ARGS_LIST: 39 | - - GraphMultiWeightedSumAggregator 40 | - [128, 256, 32, 4] 41 | - - GraphMultiWeightedSumAggregator 42 | - [128, 256, 32, 4] 43 | DROPOUT_RATE_LIST: [0.1, 0.1] 44 | DENSE_CONNECT: false 45 | L2_NORMALIZATION: false 46 | EVERY_LAYER_L2_NORMALIZATION: false 47 | EMBED_DIM: 64 48 | NEG_WEIGHT: 1.0 49 | TRAIN_NEG_SAMPLE_SCALE: 0 50 | TRAIN_NEG_SAMPLE_REPLACE: false 51 | VALID_NEG_SAMPLE_SCALE: 0 52 | TEST_NEG_SAMPLE_SCALE: 0 53 | TRAIN: 54 | BATCH_SIZE: 512 55 | GRAPH_SAMPLER_ARGS: 56 | - fixed 57 | - [25, 10] 58 | VALID_ITER: 100 59 | TEST_ITER: 100 60 | MAX_ITER: 100000 61 | OPTIMIZER: adam 62 | LR: 0.001 63 | MIN_LR: 0.0001 64 | DECAY_PATIENCE: 4 65 | EARLY_STOPPING_PATIENCE: 10 66 | LR_DECAY_FACTOR: 0.5 67 | GRAD_CLIP: 1.0 68 | WD: 0.0 69 | TEST: 70 | BATCH_SIZE: 512 71 | 
SAMPLE_NUM: 1 72 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_reddit_sup/215_reddit_sup_d0_multi_weighted_div1_sigmoid_128_256_32_4_128_256_32_4_fixed_25_10_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 215 2 | NPY_SEED: 215 3 | DATA_NAME: reddit 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | DIVIDE_SIZE: true 16 | WEIGHT_ACT: sigmoid 17 | GRAPH_MULTI_WEIGHTED_SUM: 18 | ARGS: [out_units, mid_units, attend_units, K] 19 | ATTEND_W_DROPOUT: 0.0 20 | DIVIDE_SIZE: true 21 | WEIGHT_ACT: sigmoid 22 | MUGGA: 23 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 24 | USE_EDGE: false 25 | RESCALE_INNERPRODUCT: true 26 | ATTEND_W_DROPOUT: 0.5 27 | CONTEXT: 28 | USE_SUM_POOL: false 29 | USE_MAX_POOL: true 30 | USE_AVG_POOL: true 31 | USE_GATE: true 32 | USE_SHARPNESS: false 33 | STATIC_GRAPH: 34 | MODEL: 35 | TYP: supervised 36 | FEATURE_NORMALIZE: true 37 | FIRST_EMBED_UNITS: 64 38 | AGGREGATOR_ARGS_LIST: 39 | - - GraphMultiWeightedSumAggregator 40 | - [128, 256, 32, 4] 41 | - - GraphMultiWeightedSumAggregator 42 | - [128, 256, 32, 4] 43 | DROPOUT_RATE_LIST: [0.1, 0.1] 44 | DENSE_CONNECT: false 45 | L2_NORMALIZATION: false 46 | EVERY_LAYER_L2_NORMALIZATION: false 47 | EMBED_DIM: 64 48 | NEG_WEIGHT: 1.0 49 | TRAIN_NEG_SAMPLE_SCALE: 0 50 | TRAIN_NEG_SAMPLE_REPLACE: false 51 | VALID_NEG_SAMPLE_SCALE: 0 52 | TEST_NEG_SAMPLE_SCALE: 0 53 | TRAIN: 54 | BATCH_SIZE: 512 55 | GRAPH_SAMPLER_ARGS: 56 | - fixed 57 | - [25, 10] 58 | VALID_ITER: 100 59 | TEST_ITER: 100 60 | MAX_ITER: 100000 61 | OPTIMIZER: adam 62 | LR: 0.001 63 | MIN_LR: 0.0001 64 | DECAY_PATIENCE: 4 65 | EARLY_STOPPING_PATIENCE: 10 66 | LR_DECAY_FACTOR: 0.5 67 | GRAD_CLIP: 1.0 68 | WD: 0.0 69 | TEST: 70 | BATCH_SIZE: 512 71 | SAMPLE_NUM: 1 72 | -------------------------------------------------------------------------------- /mxgraph/layers/loss.py: -------------------------------------------------------------------------------- 1 | from mxnet.gluon import nn, HybridBlock 2 | from mxnet.gluon.loss import Loss, LogisticLoss 3 | 4 | 5 | class LinkPredLogisticLoss(Loss): 6 | def __init__(self, neg_weight=1.0, weight=None, **kwargs): 7 | super(LinkPredLogisticLoss, self).__init__(weight=weight, batch_axis=None, **kwargs) 8 | self._neg_weight = neg_weight 9 | 10 | def hybrid_forward(self, F, pred, edge_label, edge_weight=None): 11 | """ 12 | 13 | Parameters 14 | ---------- 15 | F 16 | pred 17 | edge_label: 1-> positive edge; -1 -> negative edge 18 | edge_weight 19 | 20 | Returns 21 | ------- 22 | 23 | """ 24 | binary_label = (edge_label + 1) / 2.0 # 1 -> positive edge; 0 -> negative edge 25 | if edge_weight is None: 26 | edge_weight = binary_label * 1.0 + (1 - binary_label) * self._neg_weight 27 | else: 28 | edge_weight *= binary_label * 1.0 + (1 - binary_label) / 2. 
* self._neg_weight 29 | #print("edge_weight", edge_weight) 30 | #print("edge_label", edge_label) 31 | loss = edge_weight * F.log(1.0 + F.exp(-pred * edge_label)) 32 | loss = F.sum(loss) 33 | return loss 34 | 35 | 36 | class UnsupWalkLoss(Loss): 37 | def __init__(self, neg_weight=1.0, weight=None, **kwargs): 38 | super(UnsupWalkLoss, self).__init__(weight=weight, batch_axis=None, **kwargs) 39 | self._neg_weight = neg_weight 40 | 41 | def hybrid_forward(self, F, node_emb, pos_emb, neg_emb): 42 | """ 43 | 44 | Parameters 45 | ---------- 46 | F 47 | node_emb : (batch_size, dim) 48 | pos_emb : (batch_size, dim) 49 | neg_emb : (neg_sample_num, dim) 50 | edge_weight 51 | 52 | Returns 53 | ------- 54 | 55 | """ 56 | pos_innerproduct = F.sum(F.broadcast_mul(node_emb, pos_emb), axis=1) ## Shape: (batch_size, ) 57 | neg_innerproduct = F.dot(node_emb, neg_emb, transpose_b=True) ## Shape: (batch_size, num_neg) 58 | 59 | pos_loss = F.log(1.0 + F.exp(-1.0 * pos_innerproduct)) 60 | neg_loss = F.log(1.0 + F.exp(neg_innerproduct)) 61 | loss = F.sum(pos_loss) + self._neg_weight * F.sum(neg_loss) 62 | 63 | return loss 64 | -------------------------------------------------------------------------------- /experiments/static_graph/cfg_reddit_sup/215_reddit_sup_d0_pool_avg_128_1024_1_128_1024_1_fixed_25_10_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 215 2 | NPY_SEED: 215 3 | DATA_NAME: reddit 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: avg 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | WEIGHT_ACT: sigmoid 16 | MUGGA: 17 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 18 | USE_EDGE: false 19 | ATTEND_W_DROPOUT: 0.5 20 | CONTEXT: 21 | USE_SUM_POOL: false 22 | USE_MAX_POOL: true 23 | USE_AVG_POOL: true 24 | USE_GATE: true 25 | USE_SHARPNESS: true 26 | STATIC_GRAPH: 27 | MODEL: 28 | TYP: supervised 29 | FEATURE_NORMALIZE: true 30 | FIRST_EMBED_UNITS: 64 31 | AGGREGATOR_ARGS_LIST: 32 | - - GraphPoolAggregator 33 | - [128, 1024, 1] 34 | - - GraphPoolAggregator 35 | - [128, 1024, 1] 36 | DROPOUT_RATE_LIST: [0.1, 0.1] 37 | DENSE_CONNECT: false 38 | L2_NORMALIZATION: false 39 | EVERY_LAYER_L2_NORMALIZATION: false 40 | EMBED_DIM: 64 41 | NEG_WEIGHT: 1.0 42 | TRAIN_NEG_SAMPLE_SCALE: 0 43 | TRAIN_NEG_SAMPLE_REPLACE: false 44 | VALID_NEG_SAMPLE_SCALE: 0 45 | TEST_NEG_SAMPLE_SCALE: 0 46 | TRAIN: 47 | BATCH_SIZE: 512 48 | GRAPH_SAMPLER_ARGS: 49 | - fixed 50 | - [25, 10] 51 | VALID_ITER: 100 52 | TEST_ITER: 100 53 | MAX_ITER: 100000 54 | OPTIMIZER: adam 55 | LR: 0.001 56 | MIN_LR: 0.0001 57 | DECAY_PATIENCE: 4 58 | EARLY_STOPPING_PATIENCE: 10 59 | LR_DECAY_FACTOR: 0.5 60 | GRAD_CLIP: 1.0 61 | WD: 0.0 62 | TEST: 63 | BATCH_SIZE: 512 64 | SAMPLE_NUM: 1 65 | SPATIOTEMPORAL_GRAPH: 66 | IN_LENGTH: 12 67 | OUT_LENGTH: 12 68 | MODEL: 69 | AGGREGATOR_ARGS: 70 | - MuGGA 71 | - [32, 8, 16, 4, 32, 1] 72 | DROPOUT_RATE: 0.0 73 | LAYER_NUM: 2 74 | USE_EDGE: false 75 | TRAIN: 76 | MAX_EPOCH: 20 77 | OPTIMIZER: adam 78 | LR: 0.001 79 | MIN_LR: 0.0001 80 | DECAY_PATIENCE: 15 81 | EARLY_STOPPING_PATIENCE: 30 82 | LR_DECAY_FACTOR: 0.5 83 | GRAD_CLIP: 10.0 84 | WD: 0.0 85 | -------------------------------------------------------------------------------- 
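A quick usage sketch for the two losses defined in `mxgraph/layers/loss.py` above (shapes follow the docstrings; the input values are made up):

```python
import mxnet as mx
from mxgraph.layers.loss import LinkPredLogisticLoss, UnsupWalkLoss

link_loss = LinkPredLogisticLoss(neg_weight=1.0)
pred = mx.nd.array([2.0, -1.0, 0.5])           # raw edge scores
edge_label = mx.nd.array([1, -1, -1])          # 1 = positive edge, -1 = negative edge
print(link_loss(pred, edge_label))             # summed weighted logistic loss

walk_loss = UnsupWalkLoss(neg_weight=1.0)
node_emb = mx.nd.random.normal(shape=(4, 16))  # (batch_size, dim)
pos_emb = mx.nd.random.normal(shape=(4, 16))   # (batch_size, dim)
neg_emb = mx.nd.random.normal(shape=(8, 16))   # (neg_sample_num, dim)
print(walk_loss(node_emb, pos_emb, neg_emb))
```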
/experiments/static_graph/cfg_reddit_sup/215_reddit_sup_d0_pool_max_128_1024_1_128_1024_1_fixed_25_10_d0.10.1.yml: -------------------------------------------------------------------------------- 1 | MX_SEED: 215 2 | NPY_SEED: 215 3 | DATA_NAME: reddit 4 | SPLIT_TRAINING: false 5 | TRAIN_SPLIT_NUM: 20 6 | LOAD_WALKS: false 7 | AGGREGATOR: 8 | ACTIVATION: leaky 9 | GRAPHPOOL: 10 | ARGS: [out_units, mid_units, mid_layer_num] 11 | POOL_TYPE: max 12 | GRAPH_WEIGHTED_SUM: 13 | ARGS: [out_units, mid_units, attend_units] 14 | ATTEND_W_DROPOUT: 0.0 15 | WEIGHT_ACT: sigmoid 16 | MUGGA: 17 | ARGS: [out_units, attend_units, value_units, K, context_units, context_layer_num] 18 | USE_EDGE: false 19 | ATTEND_W_DROPOUT: 0.5 20 | CONTEXT: 21 | USE_SUM_POOL: false 22 | USE_MAX_POOL: true 23 | USE_AVG_POOL: true 24 | USE_GATE: true 25 | USE_SHARPNESS: true 26 | STATIC_GRAPH: 27 | MODEL: 28 | TYP: supervised 29 | FEATURE_NORMALIZE: true 30 | FIRST_EMBED_UNITS: 64 31 | AGGREGATOR_ARGS_LIST: 32 | - - GraphPoolAggregator 33 | - [128, 1024, 1] 34 | - - GraphPoolAggregator 35 | - [128, 1024, 1] 36 | DROPOUT_RATE_LIST: [0.1, 0.1] 37 | DENSE_CONNECT: false 38 | L2_NORMALIZATION: false 39 | EVERY_LAYER_L2_NORMALIZATION: false 40 | EMBED_DIM: 64 41 | NEG_WEIGHT: 1.0 42 | TRAIN_NEG_SAMPLE_SCALE: 0 43 | TRAIN_NEG_SAMPLE_REPLACE: false 44 | VALID_NEG_SAMPLE_SCALE: 0 45 | TEST_NEG_SAMPLE_SCALE: 0 46 | TRAIN: 47 | BATCH_SIZE: 512 48 | GRAPH_SAMPLER_ARGS: 49 | - fixed 50 | - [25, 10] 51 | VALID_ITER: 100 52 | TEST_ITER: 100 53 | MAX_ITER: 100000 54 | OPTIMIZER: adam 55 | LR: 0.001 56 | MIN_LR: 0.0001 57 | DECAY_PATIENCE: 4 58 | EARLY_STOPPING_PATIENCE: 10 59 | LR_DECAY_FACTOR: 0.5 60 | GRAD_CLIP: 1.0 61 | WD: 0.0 62 | TEST: 63 | BATCH_SIZE: 512 64 | SAMPLE_NUM: 1 65 | SPATIOTEMPORAL_GRAPH: 66 | IN_LENGTH: 12 67 | OUT_LENGTH: 12 68 | MODEL: 69 | AGGREGATOR_ARGS: 70 | - MuGGA 71 | - [32, 8, 16, 4, 32, 1] 72 | DROPOUT_RATE: 0.0 73 | LAYER_NUM: 2 74 | USE_EDGE: false 75 | TRAIN: 76 | MAX_EPOCH: 20 77 | OPTIMIZER: adam 78 | LR: 0.001 79 | MIN_LR: 0.0001 80 | DECAY_PATIENCE: 15 81 | EARLY_STOPPING_PATIENCE: 30 82 | LR_DECAY_FACTOR: 0.5 83 | GRAD_CLIP: 10.0 84 | WD: 0.0 85 | -------------------------------------------------------------------------------- /GraphSampler/cmake/Modules/FindNumpy.cmake: -------------------------------------------------------------------------------- 1 | # - Find the NumPy libraries 2 | # This module finds if NumPy is installed, and sets the following variables 3 | # indicating where it is. 4 | # 5 | # TODO: Update to provide the libraries and paths for linking npymath lib. 6 | # 7 | # NUMPY_FOUND - was NumPy found 8 | # NUMPY_VERSION - the version of NumPy found as a string 9 | # NUMPY_VERSION_MAJOR - the major version number of NumPy 10 | # NUMPY_VERSION_MINOR - the minor version number of NumPy 11 | # NUMPY_VERSION_PATCH - the patch version number of NumPy 12 | # NUMPY_VERSION_DECIMAL - e.g. 
version 1.6.1 is 10601 13 | # NUMPY_INCLUDE_DIR - path to the NumPy include files 14 | 15 | unset(NUMPY_VERSION) 16 | unset(NUMPY_INCLUDE_DIR) 17 | 18 | if(PYTHONINTERP_FOUND) 19 | execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" 20 | "import numpy as n; print(n.__version__); print(n.get_include());" 21 | RESULT_VARIABLE __result 22 | OUTPUT_VARIABLE __output 23 | OUTPUT_STRIP_TRAILING_WHITESPACE) 24 | 25 | if(__result MATCHES 0) 26 | string(REGEX REPLACE ";" "\\\\;" __values ${__output}) 27 | string(REGEX REPLACE "\r?\n" ";" __values ${__values}) 28 | list(GET __values 0 NUMPY_VERSION) 29 | list(GET __values 1 NUMPY_INCLUDE_DIR) 30 | 31 | string(REGEX MATCH "^([0-9])+\\.([0-9])+\\.([0-9])+" __ver_check "${NUMPY_VERSION}") 32 | if(NOT "${__ver_check}" STREQUAL "") 33 | set(NUMPY_VERSION_MAJOR ${CMAKE_MATCH_1}) 34 | set(NUMPY_VERSION_MINOR ${CMAKE_MATCH_2}) 35 | set(NUMPY_VERSION_PATCH ${CMAKE_MATCH_3}) 36 | math(EXPR NUMPY_VERSION_DECIMAL 37 | "(${NUMPY_VERSION_MAJOR} * 10000) + (${NUMPY_VERSION_MINOR} * 100) + ${NUMPY_VERSION_PATCH}") 38 | string(REGEX REPLACE "\\\\" "/" NUMPY_INCLUDE_DIR ${NUMPY_INCLUDE_DIR}) 39 | else() 40 | unset(NUMPY_VERSION) 41 | unset(NUMPY_INCLUDE_DIR) 42 | message(STATUS "Requested NumPy version and include path, but got instead:\n${__output}\n") 43 | endif() 44 | endif() 45 | else() 46 | message(STATUS "To find NumPy, the Python interpreter must be found first.") 47 | endif() 48 | 49 | include(FindPackageHandleStandardArgs) 50 | find_package_handle_standard_args(NumPy REQUIRED_VARS NUMPY_INCLUDE_DIR NUMPY_VERSION 51 | VERSION_VAR NUMPY_VERSION) 52 | 53 | if(NUMPY_FOUND) 54 | message(STATUS "NumPy ver. ${NUMPY_VERSION} found (include: ${NUMPY_INCLUDE_DIR})") 55 | endif() 56 | -------------------------------------------------------------------------------- /download_data.py: -------------------------------------------------------------------------------- 1 | from urllib import request 2 | import argparse 3 | import ssl 4 | import os 5 | 6 | # HTTP_PROXY = 'http://dev-proxy.oa.com:8080' 7 | # HTTPS_PROXY = 'https://dev-proxy.oa.com:8080' 8 | # os.environ['http_proxy'] = HTTP_PROXY 9 | # os.environ['HTTP_PROXY'] = HTTP_PROXY 10 | # os.environ['https_proxy'] = HTTPS_PROXY 11 | # os.environ['HTTPS_PROXY'] = HTTPS_PROXY 12 | 13 | if not os.path.exists("datasets"): 14 | os.mkdir('datasets') 15 | if not os.path.exists(os.path.join("datasets", "ppi")): 16 | os.mkdir('datasets/ppi') 17 | if not os.path.exists(os.path.join("datasets", "reddit")): 18 | os.mkdir('datasets/reddit') 19 | if not os.path.exists(os.path.join("datasets", "cora")): 20 | os.mkdir('datasets/cora') 21 | 22 | ssl._create_default_https_context = ssl._create_unverified_context 23 | 24 | 25 | download_cora = [os.path.join('datasets', 'cora.zip'), 26 | 'https://www.dropbox.com/sh/q1ms4e1qgbml6lj/AADO4HoFj5Y76NNoQNNB45Sga?dl=1'] 27 | download_ppi = [os.path.join('datasets', 'ppi.zip'), 28 | 'https://www.dropbox.com/sh/brmvu4dnjced6rb/AAArmBA5O_JMIZShlNftqj5Ca?dl=1'] 29 | download_reddit = [os.path.join('datasets', 'reddit.zip'), 30 | 'https://www.dropbox.com/sh/jbodwifw54za0dm/AAAzFV2pDzbGSduvMXqUhPhZa?dl=1'] 31 | download_traffic = [[os.path.join('datasets', 'traffic_LA', 'traffic_data.h5'), 32 | 'https://www.dropbox.com/s/7fbafmbiyjb96n4/df_highway_2012_4mon_sample.h5?dl=1'], 33 | [os.path.join('datasets', 'traffic_SF', 'traffic_data.h5'), 34 | 'https://www.dropbox.com/s/nf6uj5zbfhepgyh/df_highway_2017_6mon_sf.h5?dl=1']] 35 | ### temporary use 36 | download_ppi_Graphsage = [os.path.join('datasets',
'ppi_Graphsage.zip'), 37 | "https://www.dropbox.com/sh/bw5t70e85no6cae/AAAYNb15UjOv_sjxgCz040PQa?dl=1"] 38 | download_reddit_Graphsage = [os.path.join('datasets', 'reddit_Graphsage.zip'), 39 | "https://www.dropbox.com/sh/fa79kg10dwn3w40/AADStZCMmVWkwf3TaHeqDUg5a?dl=1"] 40 | 41 | 42 | parser = argparse.ArgumentParser(description='Downloading the necessary data') 43 | parser.add_argument('--overwrite', dest='overwrite', action='store_true', 44 | help='Whether to overwrite the stored data files') 45 | parser.add_argument('--dataset', type=str, default='cora', help='the dataset name you want to download') 46 | 47 | args = parser.parse_args() 48 | download_jobs = [] 49 | #if args.dataset == "cora": 50 | # download_jobs.append(download_cora) 51 | if args.dataset == "ppi": 52 | download_jobs.append(download_ppi) 53 | elif args.dataset == "reddit": 54 | download_jobs.append(download_reddit) 55 | #elif args.dataset == "traffic": 56 | # download_jobs.extend(download_traffic) 57 | elif args.dataset == "all": 58 | # download_jobs.append(download_cora) 59 | download_jobs.append(download_ppi) 60 | download_jobs.append(download_reddit) 61 | # download_jobs.extend(download_traffic) 62 | 63 | for target_path, src_path in download_jobs: 64 | if not os.path.exists(target_path) or args.overwrite: 65 | print('Downloading from %s to %s...' % (src_path, target_path)) 66 | data_file = request.urlopen(src_path) 67 | with open(target_path, 'wb') as output: 68 | output.write(data_file.read()) 69 | print('Done!') 70 | else: 71 | print('Found %s' % target_path) 72 | 73 | def unzip_dataset(data_name): 74 | if data_name == "cora": 75 | subprocess.call(["unzip", "datasets/cora.zip", "-d", "datasets/cora"]) 76 | subprocess.call(["rm", "datasets/cora.zip"]) 77 | print("Downloaded the cora dataset!\n") 78 | elif data_name == "ppi": 79 | subprocess.call(["unzip", "datasets/ppi.zip", "-d", "datasets/ppi"]) 80 | subprocess.call(["rm", "datasets/ppi.zip"]) 81 | print("Downloaded the ppi dataset!\n") 82 | elif data_name == "reddit": 83 | subprocess.call(["unzip", "datasets/reddit.zip", "-d", "datasets/reddit"]) 84 | subprocess.call(["rm", "datasets/reddit.zip"]) 85 | print("Downloaded the reddit dataset!\n") 86 | 87 | import subprocess 88 | 89 | #if args.dataset == "cora" or args.dataset == "all": 90 | # unzip_dataset("cora") 91 | if args.dataset == "ppi" or args.dataset == "all": 92 | unzip_dataset("ppi") 93 | if args.dataset == "reddit" or args.dataset == "all": 94 | unzip_dataset("reddit") 95 | 96 | subprocess.call(["rm", "-fr", "datasets/__MACOSX"]) 97 | -------------------------------------------------------------------------------- /mxgraph/layers/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mxnet as mx 3 | import numpy as np 4 | import mxnet.ndarray as nd 5 | import mxnet.gluon as gluon 6 | from mxnet.gluon import nn, HybridBlock, Block 7 | 8 | 9 | class IdentityActivation(HybridBlock): 10 | def hybrid_forward(self, F, x): 11 | return x 12 | 13 | 14 | class ELU(HybridBlock): 15 | r""" 16 | Exponential Linear Unit (ELU) 17 | "Fast and Accurate Deep Network Learning by Exponential Linear Units", Clevert et al, 2016 18 | https://arxiv.org/abs/1511.07289 19 | Published as a conference paper at ICLR 2016 20 | Parameters 21 | ---------- 22 | alpha : float 23 | The alpha parameter as described by Clevert et al, 2016 24 | Inputs: 25 | - **data**: input tensor with arbitrary shape. 26 | Outputs: 27 | - **out**: output tensor with the same shape as `data`. 
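    Element-wise, the hybrid_forward below computes
        f(x) = x                     if x > 0
        f(x) = alpha * (exp(x) - 1)  otherwise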
    """
    def __init__(self, alpha=1.0, **kwargs):
        super(ELU, self).__init__(**kwargs)
        self._alpha = alpha

    def hybrid_forward(self, F, x):
        return - self._alpha * F.relu(1.0 - F.exp(x)) + F.relu(x)


def get_activation(act):
    """Get the activation based on the act string

    Parameters
    ----------
    act: str or HybridBlock

    Returns
    -------
    ret: HybridBlock
    """
    if isinstance(act, str):
        if act == 'leaky':
            return nn.LeakyReLU(0.1)
        elif act == 'identity':
            return IdentityActivation()
        elif act == 'elu':
            return ELU()
        else:
            return nn.Activation(act)
    else:
        return act


class DenseNetBlock(HybridBlock):
    def __init__(self, units, layer_num, act, flatten=False, prefix=None, params=None):
        super(DenseNetBlock, self).__init__(prefix=prefix, params=params)
        self._units = units
        self._layer_num = layer_num
        self._act = get_activation(act)
        with self.name_scope():
            self.layers = nn.HybridSequential('dblock_')
            with self.layers.name_scope():
                for _ in range(layer_num):
                    self.layers.add(nn.Dense(units, flatten=flatten))

    def hybrid_forward(self, F, x):
        layer_in_l = [x]
        layer_out = None
        for i in range(self._layer_num):
            if len(layer_in_l) == 1:
                layer_in = layer_in_l[0]
            else:
                layer_in = F.concat(*layer_in_l, dim=-1)
            layer_out = self._act(self.layers[i](layer_in))
            layer_in_l.append(layer_out)
        return layer_out


class HeterDenseNetBlock(Block):
    def __init__(self, units, layer_num, act, num_set, flatten=False, prefix=None, params=None):
        super(HeterDenseNetBlock, self).__init__(prefix=prefix, params=params)
        self._units = units
        self._num_set = num_set
        self._layer_num = layer_num
        self._act = get_activation(act)
        with self.name_scope():
            self.layers = nn.Sequential('hdblock_')
            with self.layers.name_scope():
                for _ in range(layer_num):
                    self.layers.add(nn.Dense(units*num_set, flatten=flatten))

    def forward(self, x, mask):
        """

        Parameters
        ----------
        x: Shape(batch_size, num_node, input_dim)
        mask: Shape(batch_size, num_node, num_set, 1)

        Returns
        -------
        layer_out: Shape(batch_size, num_node, units)
        """
        layer_in_l = [x]
        layer_out = None
        for i in range(self._layer_num):
            if len(layer_in_l) == 1:
                layer_in = layer_in_l[0]
            else:
                layer_in = nd.concat(*layer_in_l, dim=-1)
            ### TODO assume batch_size=1
            x_mW = nd.reshape(self.layers[i](layer_in), shape=(0, 0, self._num_set, self._units))
            layer_out = self._act( nd.sum(nd.broadcast_mul(x_mW, mask), axis=-2) )
            layer_in_l.append(layer_out)
        return layer_out


class L2Normalization(HybridBlock):
    def __init__(self, axis=-1, eps=1E-6, prefix=None, params=None):
        super(L2Normalization, self).__init__(prefix=prefix, params=params)
        self._axis = axis
        self._eps = eps

    def hybrid_forward(self, F, x):
        ret = F.broadcast_div(x, F.sqrt(F.sum(F.square(x), axis=self._axis, keepdims=True)
                                        + self._eps))
        return ret
--------------------------------------------------------------------------------
/seg_ops_cuda/README.md:
--------------------------------------------------------------------------------
Implementation of the Segmented Operators
-----------------------------------------
To use the following operators, copy the `seg_op.*` files in [mxnet_op](mxnet_op) to `incubator-mxnet/src/operator/contrib`. All of them share a CSR-style layout: `indptr` has `seg_num + 1` entries and segment `i` covers positions `indptr[i]` to `indptr[i + 1] - 1` of the last axis.

- [seg_sum](#seg_sum)
- [seg_broadcast_add](#seg_broadcast_add)
- [seg_broadcast_mul](#seg_broadcast_mul)
- [seg_broadcast_to](#seg_broadcast_to)
- [seg_softmax](#seg_softmax)
- [seg_take_k_corr](#seg_take_k_corr)
- [seg_weighted_pool](#seg_weighted_pool)
- [seg_pool](#seg_pool)

After copying and recompiling MXNet, test the operators by running

```bash
python mxnet_op/test_seg_ops.py
```

## seg_sum

Sum over the last dimension of the input based on the given segment indicators.

Inputs:
- data: Shape (batch_num, nnz)
- indptr: Shape (seg_num + 1,)

Outputs:
- ret: Shape (batch_num, seg_num)

```c++
for k = 0 to batch_num - 1
    for i = 0 to seg_num - 1
        ret[k, i] = sum(data[k, indptr[i]], ..., data[k, indptr[i + 1] - 1])
```

Examples::

    out = seg_sum(data=data, indptr=indptr)

## seg_broadcast_add

Broadcast rhs according to the segment indicators and add it to lhs to get the result.

Inputs:
- lhs: Shape (batch_num, nnz)
- rhs: Shape (batch_num, seg_num)
- indptr: Shape (seg_num + 1,)

Outputs:
- ret: Shape (batch_num, nnz)

Examples::

    ret = seg_broadcast_add(lhs=lhs, rhs=rhs, indptr=indptr)

## seg_broadcast_mul

Broadcast rhs according to the segment indicators and multiply it with lhs to get the result.

Inputs:
- lhs: Shape (batch_num, nnz)
- rhs: Shape (batch_num, seg_num)
- indptr: Shape (seg_num + 1,)

Outputs:
- ret: Shape (batch_num, nnz)

Examples::

    ret = seg_broadcast_mul(lhs=lhs, rhs=rhs, indptr=indptr)

## seg_broadcast_to

Broadcast each segment value of data to all nnz positions of that segment.

Inputs:
- data: Shape (batch_num, seg_num)
- indptr: Shape (seg_num + 1,)
- int nnz

Outputs:
- ret: Shape (batch_num, nnz)

Examples::

    ret = seg_broadcast_to(data=data, indptr=indptr, nnz=nnz)

## seg_softmax

Calculate the softmax of the input based on the given segment indicators.

Inputs:
- data: Shape (batch_num, nnz)
- indptr: Shape (seg_num + 1,)

Outputs:
- ret: Shape (batch_num, nnz)

```c++
for k = 0 to batch_num - 1
    for i = 0 to seg_num - 1
        ret[k, indptr[i]:indptr[i+1]] = softmax(data[k, indptr[i]:indptr[i+1]])
```

Examples::

    out = seg_softmax(data=data, indptr=indptr)
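
As a reference for the layout, here is a minimal NumPy sketch of `seg_sum` (a pure-Python stand-in for checking shapes and semantics, not the CUDA kernel; `seg_sum_ref` is an illustrative name):

```python
import numpy as np

def seg_sum_ref(data, indptr):
    """data: (batch_num, nnz), indptr: (seg_num + 1,) -> ret: (batch_num, seg_num)"""
    seg_num = len(indptr) - 1
    ret = np.zeros((data.shape[0], seg_num), dtype=data.dtype)
    for i in range(seg_num):
        # segment i owns columns indptr[i] .. indptr[i + 1] - 1
        ret[:, i] = data[:, indptr[i]:indptr[i + 1]].sum(axis=1)
    return ret

data = np.array([[1., 2., 3., 4.]])
indptr = np.array([0, 2, 4])
print(seg_sum_ref(data, indptr))  # [[3. 7.]]
```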
## seg_take_k_corr

For every node, compute the inner products between the node and its neighbors and add them to dst.
We assume the node_ids are 0, 1, 2, ..., node_num - 1.

Inputs:
- embed1: Shape (K, node_num, feat_dim)
- embed2: Shape (K, neighbor_node_num, feat_dim)
- neighbor_ids: Shape (nnz,)
- neighbor_indptr: Shape (node_num + 1,)

Outputs:
- dst: Shape (K, nnz)

```c++
for k = 0 to K-1
    for i = 0 to node_num - 1
        for j = neighbor_indptr[i] to neighbor_indptr[i+1] - 1
            neighbor_id = neighbor_ids[j]
            dst[k, j] += InnerProduct(embed1[k, i], embed2[k, neighbor_id])
```

Examples::

    out = seg_take_k_corr(embed1=embed1, embed2=embed2, neighbor_ids=neighbor_ids, neighbor_indptr=neighbor_indptr)

## seg_weighted_pool

Compute the weighted average of the values in the segments.

Inputs:
- data: Shape (batch_size, total_ind_num, feat_dim)
- weights: Shape (batch_size, nnz)
- indices: Shape (nnz,)
- indptr: Shape (seg_num + 1,)

Outputs:
- dst: Shape (batch_size, seg_num, feat_dim)

```c++
for k = 0 to batch_size - 1
    for i = 0 to seg_num - 1
        for j = indptr[i] to indptr[i+1] - 1
            dst[k, i, :] += weights[k, j] * data[k, indices[j], :]
```

Examples::

    out = seg_weighted_pool(data=data, weights=weights, indices=indices, indptr=indptr)

## seg_pool

Pooling of the values in the segments.

Inputs:
- data : Shape (batch_size, total_ind_num, feat_dim)
- indices : Shape (nnz,)
- indptr : Shape (seg_num + 1,)
- pool_type : 'avg' or 'sum' or 'max'

Outputs:
- dst : Shape (batch_size, seg_num, feat_dim)

Examples::

    out = seg_pool(data=data,
                   indices=indices,
                   indptr=indptr,
                   pool_type='avg')
--------------------------------------------------------------------------------
/mxgraph/cfg_helper.py:
--------------------------------------------------------------------------------
def generate_file_name_static(local_cfg, model=None):
    """

    Parameters
    ----------
    local_cfg : OrderedEasyDict
    model : str

    Returns
    -------
    file_name
    """
    file_name = str(local_cfg.MX_SEED) + "_" + local_cfg.DATA_NAME
    if local_cfg.DATA_NAME == 'ppi':
        file_name += "_" + str(local_cfg.TRAIN_SPLIT_NUM)
    if model is None:
        if local_cfg.STATIC_GRAPH.MODEL.AGGREGATOR_ARGS_LIST[0][0].lower() == "mugga":
            model = "mugga"
        elif local_cfg.STATIC_GRAPH.MODEL.AGGREGATOR_ARGS_LIST[0][0].lower() == "GraphPoolAggregator".lower():
            model = "pool"
        elif local_cfg.STATIC_GRAPH.MODEL.AGGREGATOR_ARGS_LIST[0][0].lower() == "GraphWeightedSumAggregator".lower():
            model = "weighted"
        elif local_cfg.STATIC_GRAPH.MODEL.AGGREGATOR_ARGS_LIST[0][0].lower() == "GraphMultiWeightedSumAggregator".lower():
            model = "multi_weighted"
        else:
            raise NotImplementedError()
    if local_cfg.STATIC_GRAPH.MODEL.TYP == "supervised":
        file_name += "_sup"
    elif local_cfg.STATIC_GRAPH.MODEL.TYP == "unsupervised":
        file_name += "_unsup"
        file_name += "_neg%d+%d+%d_%g" % (local_cfg.STATIC_GRAPH.MODEL.TRAIN_NEG_SAMPLE_SCALE,
                                          local_cfg.STATIC_GRAPH.MODEL.VALID_NEG_SAMPLE_SCALE,
                                          local_cfg.STATIC_GRAPH.MODEL.TEST_NEG_SAMPLE_SCALE,
                                          local_cfg.STATIC_GRAPH.MODEL.NEG_WEIGHT)
        file_name += "_emb%d" % local_cfg.STATIC_GRAPH.MODEL.EMBED_DIM
    elif local_cfg.STATIC_GRAPH.MODEL.TYP == "transductive":
        file_name += "_trans"
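
    # The suffix built below encodes the dense-connect flag, the aggregator and
    # its options, the per-layer units, the sampler args and the dropout rates,
    # e.g. (illustrative) "1206_ppi_20_sup_d0_mugga_g1_s0_128_..._all_2_d0.10.1".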
file_name += "_d" + str(int(local_cfg.STATIC_GRAPH.MODEL.DENSE_CONNECT)) 39 | file_name += "_" + model 40 | if model == 'mugga': 41 | # file_name += "_sp" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_SUM_POOL)) 42 | # file_name += "_mp" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_MAX_POOL)) 43 | # file_name += "_ap" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_AVG_POOL)) 44 | file_name += "_g" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_GATE)) 45 | file_name += "_s" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_SHARPNESS)) 46 | elif model == 'pool': 47 | file_name += "_" + local_cfg.AGGREGATOR.GRAPHPOOL.POOL_TYPE 48 | elif model == 'weighted': 49 | file_name += "_div" + str(int(local_cfg.AGGREGATOR.GRAPH_WEIGHTED_SUM.DIVIDE_SIZE)) 50 | file_name += "_" + local_cfg.AGGREGATOR.GRAPH_WEIGHTED_SUM.WEIGHT_ACT 51 | elif model == 'multi_weighted': 52 | file_name += "_div" + str(int(local_cfg.AGGREGATOR.GRAPH_MULTI_WEIGHTED_SUM.DIVIDE_SIZE)) 53 | file_name += "_" + local_cfg.AGGREGATOR.GRAPH_MULTI_WEIGHTED_SUM.WEIGHT_ACT 54 | else: 55 | raise NotImplementedError 56 | for layer_info in local_cfg.STATIC_GRAPH.MODEL.AGGREGATOR_ARGS_LIST: 57 | for units in layer_info[1]: 58 | file_name += '_' + str(units) 59 | file_name += '_' + local_cfg.STATIC_GRAPH.MODEL.TRAIN.GRAPH_SAMPLER_ARGS[0] 60 | if isinstance(local_cfg.STATIC_GRAPH.MODEL.TRAIN.GRAPH_SAMPLER_ARGS[1], int): 61 | file_name += '_' + str(local_cfg.STATIC_GRAPH.MODEL.TRAIN.GRAPH_SAMPLER_ARGS[1]) 62 | else: 63 | for ele in local_cfg.STATIC_GRAPH.MODEL.TRAIN.GRAPH_SAMPLER_ARGS[1]: 64 | file_name += '_' + str(ele) 65 | file_name += '_d' 66 | for dropout in local_cfg.STATIC_GRAPH.MODEL.DROPOUT_RATE_LIST: 67 | file_name += '%g' % (dropout) 68 | if local_cfg.STATIC_GRAPH.MODEL.L2_NORMALIZATION: 69 | file_name += '_norm' 70 | return file_name 71 | 72 | 73 | def generate_file_name_spatiotemporal(local_cfg): 74 | """ 75 | 76 | Parameters 77 | ---------- 78 | local_cfg 79 | model 80 | 81 | Returns 82 | ------- 83 | file_name 84 | """ 85 | file_name = str(local_cfg.MX_SEED) + "_" + local_cfg.DATA_NAME 86 | file_name += "_" + local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.RNN_TYPE 87 | file_name += "_" + local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.AGGREGATION_TYPE 88 | first_layer_args = local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.AGGREGATOR_ARGS_LIST[0] 89 | if first_layer_args[0].lower() == "mugga": 90 | model = "mugga" 91 | elif first_layer_args[0].lower()\ 92 | == "GraphPoolAggregator".lower(): 93 | model = "pool" 94 | elif first_layer_args[0].lower()\ 95 | == "GraphWeightedSumAggregator".lower(): 96 | model = "weighted" 97 | elif first_layer_args[0].lower()\ 98 | == "GraphMultiWeightedSumAggregator".lower(): 99 | model = "multi_weighted" 100 | else: 101 | raise NotImplementedError() 102 | file_name += "_" + model 103 | if model == 'mugga': 104 | # file_name += "_sp" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_SUM_POOL)) 105 | # file_name += "_mp" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_MAX_POOL)) 106 | # file_name += "_ap" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_AVG_POOL)) 107 | file_name += "_g" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_GATE)) 108 | file_name += "_s" + str(int(local_cfg.AGGREGATOR.MUGGA.CONTEXT.USE_SHARPNESS)) 109 | elif model == 'pool': 110 | file_name += "_" + local_cfg.AGGREGATOR.GRAPHPOOL.POOL_TYPE 111 | elif model == 'weighted': 112 | file_name += "_" + local_cfg.AGGREGATOR.GRAPH_WEIGHTED_SUM.WEIGHT_ACT 113 | elif model == 'multi_weighted': 114 | file_name += "_" + 
local_cfg.AGGREGATOR.GRAPH_MULTI_WEIGHTED_SUM.WEIGHT_ACT
    else:
        raise NotImplementedError
    state_dim = first_layer_args[1][0]
    file_name += "_s" + str(state_dim)
    for i, aggregator_args in enumerate(local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.AGGREGATOR_ARGS_LIST):
        file_name += "_l%d" % i
        for ele in aggregator_args[1]:
            file_name += "_" + str(ele)
    file_name += '_tau%d_%d_%d' % (local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.SCHEDULED_SAMPLING.TAU,
                                   local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.INITIAL_EPOCHS,
                                   local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.DECAY_PATIENCE)
    file_name += '_edge%d' % int(local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.USE_EDGE)
    file_name += '_d%g' % local_cfg.SPATIOTEMPORAL_GRAPH.MODEL.DROPOUT_RATE
    return file_name
--------------------------------------------------------------------------------
/mxgraph/utils.py:
--------------------------------------------------------------------------------
import ast
import os
import inspect
import logging
import re
try:
    import mxnet.ndarray as nd
    import numpy as np
except ImportError:
    import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import SGDClassifier


def safe_eval(expr):
    if type(expr) is str:
        return ast.literal_eval(expr)
    else:
        return expr


def get_name_id(dir_path):
    name_id = 0
    file_path = os.path.join(dir_path, 'cfg%d.yml' % name_id)
    while os.path.exists(file_path):
        name_id += 1
        file_path = os.path.join(dir_path, 'cfg%d.yml' % name_id)
    return name_id


def logging_config(folder=None, name=None,
                   level=logging.DEBUG,
                   console_level=logging.DEBUG,
                   no_console=True):
    """

    Parameters
    ----------
    folder : str or None
    name : str or None
    level : int
    console_level : int
    no_console : bool
        Whether to disable the console log

    Returns
    -------
    folder : str
        The folder that holds the log file
    """
    if name is None:
        name = inspect.stack()[1][1].split('.')[0]
    if folder is None:
        folder = os.path.join(os.getcwd(), name)
    if not os.path.exists(folder):
        os.makedirs(folder)
    # Remove all the current handlers
    for handler in logging.root.handlers:
        logging.root.removeHandler(handler)
    logging.root.handlers = []
    logpath = os.path.join(folder, name + ".log")
    print("All Logs will be saved to %s" % logpath)
    logging.root.setLevel(level)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logfile = logging.FileHandler(logpath)
    logfile.setLevel(level)
    logfile.setFormatter(formatter)
    logging.root.addHandler(logfile)
    if not no_console:
        # Initialize the console logging
        logconsole = logging.StreamHandler()
        logconsole.setLevel(console_level)
        logconsole.setFormatter(formatter)
        logging.root.addHandler(logconsole)
    return folder


def parse_ctx(ctx_args):
    import mxnet as mx
    ctx = re.findall(r'([a-z]+)(\d*)', ctx_args)
    ctx = [(device, int(num)) if len(num) > 0 else (device, 0) for device, num in ctx]
    ctx = [mx.Context(*ele) for ele in ctx]
    return ctx
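

# Example (illustrative): parse_ctx('gpu0,gpu1') -> [mx.gpu(0), mx.gpu(1)]
# and parse_ctx('cpu') -> [mx.cpu(0)].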
def gluon_total_param_num(net):
    return sum([np.prod(v.shape) for v in net.collect_params().values()])


def gluon_log_net_info(net, save_path):
    f = open(save_path, 'w')
    f.write('Total Param Number: %d\n' % gluon_total_param_num(net))
    f.write('Params:\n')
    for k, v in net.collect_params().items():
        f.write('\t%s: %s, %d\n' % (k, str(v.shape), np.prod(v.shape)))
    f.write(str(net))
    f.close()


def get_global_norm(arrays):
    ctx = arrays[0].context
    total_norm = nd.add_n(*[nd.dot(x, x).as_in_context(ctx)
                            for x in (arr.reshape((-1,)) for arr in arrays)])
    total_norm = nd.sqrt(total_norm).asscalar()
    return total_norm


def div_up(a, b):
    return (a + b - 1) // b


def copy_to_ctx(data, ctx, dtype=None):
    if isinstance(data, (list, tuple)):
        if dtype is None:
            dtype = data[0].dtype
        return [nd.array(ele, dtype=dtype, ctx=ctx) for ele in data]
    else:
        if dtype is None:
            dtype = data.dtype
        return nd.array(data, dtype=dtype, ctx=ctx)


def nd_acc(pred, label):
    """Evaluate accuracy using mx.nd.NDArray

    Parameters
    ----------
    pred : nd.NDArray
    label : nd.NDArray

    Returns
    -------
    acc : float
    """
    return nd.sum(pred == label).asscalar() / float(pred.size)


def nd_f1(pred, label, num_class, average="micro"):
    """Evaluate F1 using mx.nd.NDArray

    Parameters
    ----------
    pred : nd.NDArray
        Shape (num, label_num) or (num,)
    label : nd.NDArray
        Shape (num, label_num) or (num,)
    num_class : int
    average : str

    Returns
    -------
    f1 : float
    """
    if pred.dtype != np.float32:
        pred = pred.astype(np.float32)
        label = label.astype(np.float32)
    assert num_class > 1
    assert pred.ndim == label.ndim
    if num_class == 2 and average == "micro":
        tp = nd.sum((pred == 1) * (label == 1)).asscalar()
        fp = nd.sum((pred == 1) * (label == 0)).asscalar()
        fn = nd.sum((pred == 0) * (label == 1)).asscalar()
        precision = float(tp) / (tp + fp)
        recall = float(tp) / (tp + fn)
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        assert num_class is not None
        pred_onehot = nd.one_hot(indices=pred, depth=num_class)
        label_onehot = nd.one_hot(indices=label, depth=num_class)
        tp = pred_onehot * label_onehot
        fp = pred_onehot * (1 - label_onehot)
        fn = (1 - pred_onehot) * label_onehot
        if average == "micro":
            tp = nd.sum(tp).asscalar()
            fp = nd.sum(fp).asscalar()
            fn = nd.sum(fn).asscalar()
            precision = float(tp) / (tp + fp)
            recall = float(tp) / (tp + fn)
            f1 = 2 * (precision * recall) / (precision + recall)
        elif average == "macro":
            if tp.ndim == 3:
                tp = nd.sum(tp, axis=(0, 1))
                fp = nd.sum(fp, axis=(0, 1))
                fn = nd.sum(fn, axis=(0, 1))
            else:
                tp = nd.sum(tp, axis=0)
                fp = nd.sum(fp, axis=0)
                fn = nd.sum(fn, axis=0)
            precision = nd.mean(tp / (tp + fp)).asscalar()
            recall = nd.mean(tp / (tp + fn)).asscalar()
            f1 = 2 * (precision * recall) / (precision + recall)
        else:
            raise NotImplementedError
    return f1
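

# Sanity check (illustrative): with pred = nd.array([1, 0, 1, 1]) and
# label = nd.array([1, 1, 1, 0]), nd_f1(pred, label, num_class=2) counts
# tp=2, fp=1, fn=1, so precision = recall = 2/3 and f1 = 2/3.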
def sklearn_logistic_regression(dataname,
                                train_embeds, train_labels,
                                valid_embeds, valid_labels,
                                test_embeds, test_labels,
                                max_iter=None, tol=0.001, alpha=0.0001):
    if not isinstance(train_embeds, np.ndarray):
        train_embeds = train_embeds.asnumpy()
    if not isinstance(valid_embeds, np.ndarray):
        valid_embeds = valid_embeds.asnumpy()
    if not isinstance(test_embeds, np.ndarray):
        test_embeds = test_embeds.asnumpy()
    if dataname == "ppi":
        classifier = MultiOutputClassifier(
            SGDClassifier(loss="log", alpha=alpha, n_jobs=-1, max_iter=max_iter, tol=tol))
        classifier.fit(train_embeds, train_labels)
    elif dataname == "cora" or dataname == "reddit":
        classifier = SGDClassifier(loss="log", alpha=alpha, n_jobs=-1, max_iter=max_iter, tol=tol)
        classifier.fit(train_embeds, train_labels)
    else:
        raise NotImplementedError
    train_pred = classifier.predict(train_embeds)
    valid_pred = classifier.predict(valid_embeds)
    test_pred = classifier.predict(test_embeds)

    train_acc = accuracy_score(y_true=train_labels.reshape((-1,)), y_pred=train_pred.reshape((-1,)))
    valid_acc = accuracy_score(y_true=valid_labels.reshape((-1,)), y_pred=valid_pred.reshape((-1,)))
    test_acc = accuracy_score(y_true=test_labels.reshape((-1,)), y_pred=test_pred.reshape((-1,)))

    train_f1 = f1_score(y_true=train_labels, y_pred=train_pred, average='micro')
    valid_f1 = f1_score(y_true=valid_labels, y_pred=valid_pred, average='micro')
    test_f1 = f1_score(y_true=test_labels, y_pred=test_pred, average='micro')

    return train_acc, train_f1, valid_acc, valid_f1, test_acc, test_f1
--------------------------------------------------------------------------------
/GraphSampler/graph_sampler.h:
--------------------------------------------------------------------------------
#ifndef GRAPH_SAMPLER_H_
#define GRAPH_SAMPLER_H_
#include <iostream>
#include <limits>
#include <random>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#define ASSERT(x) if(!(x)) {std::cout << "Line:" << __LINE__ << " " #x << " does not hold!" << std::endl;exit(0);}
namespace graph_sampler {
typedef std::unordered_map<int, std::unordered_set<int> > GRAPH_DATA_T;
typedef std::mt19937 RANDOM_ENGINE;
const static int MAX_RANDOM_ENGINE_NUM = 128;
const static int MAX_ALLOWED_NODE = std::numeric_limits<int>::max();
const static long long MAX_ALLOWED_EDGE = std::numeric_limits<long long>::max();

class SimpleGraph {
public:
  SimpleGraph() {}
  SimpleGraph(bool undirected,
              int max_node_num = std::numeric_limits<int>::max(),
              long long max_edge_num = std::numeric_limits<long long>::max()) {
    set_undirected(undirected);
    set_max(max_node_num, max_edge_num);
  }

  void set_undirected(bool undirected) {
    undirected_ = undirected;
  }

  void set_max(int max_node_num,
               long long max_edge_num) {
    max_node_num_ = max_node_num;
    max_edge_num_ = max_edge_num;
    if (max_node_num_ < 0) max_node_num_ = MAX_ALLOWED_NODE;
    if (max_edge_num_ < 0) max_edge_num_ = MAX_ALLOWED_EDGE;
  }

  bool undirected() const { return undirected_; }
  int node_num() const { return node_num_; }
  long long edge_num() const { return edge_num_; }
  const GRAPH_DATA_T* data() const { return &data_; }

  bool is_full() {
    return edge_num_ >= max_edge_num_ || node_num_ >= max_node_num_;
  }

  bool has_node(int node) {
    return data_.find(node) != data_.end();
  }

  bool insert_new_node(int node) {
    if (is_full()) {
      return false;
    }
    GRAPH_DATA_T::iterator node_it = data_.find(node);
    if (node_it == data_.end()) {
      data_[node] = std::unordered_set<int>();
      node_num_++;
    }
    return true;
  }
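
  // Insert the edge (first -> second), creating missing endpoints on the fly.
  // For undirected graphs the reverse edge is mirrored; a self-loop only
  // registers its node. Returns false if the graph is already full.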
  bool insert_new_edge(std::pair<int, int> edge) {
    if(is_full()) {
      return false;
    }
    // 1. Insert the start point and end point
    GRAPH_DATA_T::iterator start_node_it = data_.find(edge.first);
    GRAPH_DATA_T::iterator end_node_it = data_.find(edge.second);
    bool has_insert_start = false;
    if (start_node_it == data_.end()) {
      has_insert_start = true;
      data_[edge.first] = std::unordered_set<int>();
      node_num_++;
      start_node_it = data_.find(edge.first);
    }
    if(end_node_it == data_.end()) {
      // Deal with the special case that the graph will be full after inserting the first node
      if (has_insert_start && is_full()) {
        data_.erase(start_node_it);
        node_num_--;
        return false;
      }
      data_[edge.second] = std::unordered_set<int>();
      node_num_++;
      end_node_it = data_.find(edge.second);
    }
    if (edge.second == edge.first) return true;  // Return if the edge is a self-loop
    if(start_node_it->second.find(edge.second) == start_node_it->second.end()) {
      start_node_it->second.insert(edge.second);
      edge_num_++;
    }
    if (undirected_) {
      if (end_node_it->second.find(edge.first) == end_node_it->second.end()) {
        end_node_it->second.insert(edge.first);
      }
    }
    return true;
  }

  bool insert_nodes(const std::vector<int> &ids) {
    std::vector<int> inserted_ids;
    for (int id: ids) {
      if(is_full()) {
        // Roll back the nodes inserted by this call before failing
        for(int insert_id: inserted_ids) {
          data_.erase(insert_id);
          node_num_--;
        }
        return false;
      }
      if(data_.find(id) != data_.end()) {
        continue;
      } else {
        inserted_ids.push_back(id);
        data_[id] = std::unordered_set<int>();
        node_num_++;
      }
    }
    return true;
  }

  void convert_to_csr(std::vector<int> *end_points,
                      std::vector<int> *ind_ptr,
                      std::vector<int> *node_ids,
                      const int* src_node_ids,
                      int src_node_size) {
    int shift = 0;
    std::unordered_map<int, int> node_id_map;
    int counter = 0;
    for(const auto& ele: data_) {
      node_id_map[ele.first] = counter;
      counter++;
    }
    for (const auto &ele: data_) {
      node_ids->push_back(src_node_ids[ele.first]);
      ind_ptr->push_back(shift);
      for (int node: ele.second) {
        end_points->push_back(node_id_map[node]);
        shift++;
      }
    }
    ind_ptr->push_back(shift);
  }
private:
  int max_node_num_ = MAX_ALLOWED_NODE;
  long long max_edge_num_ = MAX_ALLOWED_EDGE;
  int node_num_ = 0;
  long long edge_num_ = 0;
  GRAPH_DATA_T data_;
  bool undirected_ = true;
};

class GraphSampler {
public:
  GraphSampler(int seed_id=-1) {
    set_seed(seed_id);
  }

  void set_seed(int seed_id) {
    std::vector<int> seeds(MAX_RANDOM_ENGINE_NUM);
    int u_seed_id = seed_id;
    if(seed_id < 0) {
      // Randomly set the seed of the engine
      std::random_device rd;
      std::uniform_int_distribution<int> dist(0, 100000);
      u_seed_id = dist(rd);
    }
    RANDOM_ENGINE base_engine;
    base_engine.seed(u_seed_id);
    std::unordered_map<int, int> pool;
    for(int i = 0; i < MAX_RANDOM_ENGINE_NUM; i++) {
      std::uniform_int_distribution<int> dist(i, 100000000);
      int val = dist(base_engine);
      if(pool.find(val) != pool.end()) {
        eng_[i].seed(pool[val]);
      } else {
        eng_[i].seed(val);
      }
      if(pool.find(i) != pool.end()) {
        pool[val] = pool[i];
      } else {
        pool[val] = i;
      }
    }
  }
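
  /*
  The samplers below draw randomness from one of the MAX_RANDOM_ENGINE_NUM
  engines seeded above; where an eng_id argument exists it picks the stream,
  which keeps concurrent sampling calls reproducible under a fixed seed.
  */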
  /*
  Sample the graph by random walk.
  At every step, we return to the initial node with probability return_prob.
  Otherwise, we jump randomly to a connected node.
  See [KDD06] Sampling from Large Graphs
  ------------------------------------------------------------------
  Params:
  src_end_points: end points in the source graph
  src_ind_ptr: ind ptr in the source graph
  src_node_ids: node ids of the source graph
  src_undirected: whether the source graph is undirected
  src_node_num: number of nodes in the source graph
  initial_node: initial node of the random walk; if negative, the initial node is chosen randomly from the original graph
  walk_length: length of the random walk
  return_prob: the probability of returning to the initial node
  max_node_num: the maximum node num allowed in the sampled subgraph
  max_edge_num: the maximum edge num allowed in the sampled subgraph
  ------------------------------------------------------------------
  Return:
  subgraph: the sampled graph
  */
  SimpleGraph* random_walk(const int* src_end_points,
                           const int* src_ind_ptr,
                           const int* src_node_ids,
                           bool src_undirected,
                           int src_node_num,
                           int initial_node,
                           int walk_length=10,
                           double return_prob=0.15,
                           int max_node_num=std::numeric_limits<int>::max(),
                           long long max_edge_num = std::numeric_limits<long long>::max(),
                           int eng_id=0);
  /*
  Draw edges from the graph by negative sampling.
  */
  void uniform_neg_sampling(const int* src_end_points,
                            const int* src_ind_ptr,
                            const int* target_indices,
                            int nnz,
                            int node_num,
                            int dst_node_num,
                            float neg_sample_scale,
                            bool replace,
                            int** dst_end_points,
                            int** dst_ind_ptr,
                            int** dst_edge_label,
                            int** dst_edge_count,
                            int* dst_nnz);

  /*
  Begin a random walk from a given index.
  */
  void get_random_walk_nodes(const int* src_end_points,
                             const int* src_ind_ptr,
                             int nnz,
                             int node_num,
                             int initial_node,
                             int max_node_num,
                             int walk_length,
                             std::vector<int>* dst_indices);

  /*
  Randomly select the neighborhoods and merge.
  */
  void random_sel_neighbor_and_merge(const int* src_end_points,
                                     const int* src_ind_ptr,
                                     const int* src_node_ids,
                                     const int* sel_indices,
                                     int nnz,
                                     int sel_node_num,
                                     int neighbor_num,
                                     float neighbor_frac,
                                     bool sample_all,
                                     bool replace,
                                     std::vector<int>* dst_end_points,
                                     std::vector<int>* dst_ind_ptr,
                                     std::vector<int>* merged_node_ids,
                                     std::vector<int>* indices_in_merged);

private:
  RANDOM_ENGINE eng_[MAX_RANDOM_ENGINE_NUM];
};

void slice_csr_mat(const int* src_end_points,
                   const float* src_values,
                   const int* src_ind_ptr,
                   const int* src_row_ids,
                   const int* src_col_ids,
                   int src_row_num,
                   int src_col_num,
                   int src_nnz,
                   const int* sel_row_indices,
                   const int* sel_col_indices,
                   int dst_row_num,
                   int dst_col_num,
                   int** dst_end_points,
                   float** dst_values,
                   int** dst_ind_ptr,
                   int** dst_row_ids,
                   int** dst_col_ids,
                   int* dst_nnz);

}  // namespace graph_sampler
#endif
--------------------------------------------------------------------------------
/mxgraph/config.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import yaml
import logging
from collections import OrderedDict, namedtuple
from
mxgraph.helpers.ordered_easydict import OrderedEasyDict as edict 7 | 8 | __C = edict() 9 | cfg = __C # type: edict() 10 | 11 | # Random seed 12 | __C.MX_SEED = 12345 13 | __C.NPY_SEED = 12345 14 | 15 | # Project directory, since config.py is supposed to be in $ROOT_DIR/mxgraph 16 | __C.ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 17 | 18 | __C.DATASET_PATH = os.path.join(__C.ROOT_DIR, 'datasets') 19 | 20 | # DATA NAME 21 | # Used by symbols factories who need to adjust for different 22 | # inputs based on dataset used. Should be set by the script. 23 | __C.DATA_NAME = 'ppi' # choice: ppi; reddit; movielens 24 | __C.DATA_VERSION = "" 25 | __C.SPLIT_TRAINING = False 26 | __C.TRAIN_SPLIT_NUM = 20 27 | __C.LOAD_WALKS = False 28 | 29 | __C.AGGREGATOR = edict() 30 | __C.AGGREGATOR.ACTIVATION = 'relu' #'leaky' 31 | 32 | __C.AGGREGATOR.GRAPHPOOL = edict() 33 | __C.AGGREGATOR.GRAPHPOOL.ARGS = ["out_units", "mid_units", "mid_layer_num"] 34 | __C.AGGREGATOR.GRAPHPOOL.POOL_TYPE = "avg" # Can be "max", "avg", "sum" or "mixed" (sum and mixed is deprecated!) 35 | 36 | __C.AGGREGATOR.HETERGRAPHPOOL = edict() 37 | __C.AGGREGATOR.HETERGRAPHPOOL.ARGS = ["out_units", "mid_units", "mid_layer_num", "num_set"] 38 | __C.AGGREGATOR.HETERGRAPHPOOL.POOL_TYPE = "avg" # Can be "max", "avg", "sum" or "mixed" (sum and mixed is deprecated!) 39 | 40 | __C.AGGREGATOR.GRAPH_WEIGHTED_SUM = edict() 41 | __C.AGGREGATOR.GRAPH_WEIGHTED_SUM.ARGS = ["out_units", "mid_units", "attend_units"] 42 | __C.AGGREGATOR.GRAPH_WEIGHTED_SUM.ATTEND_W_DROPOUT = 0.0 43 | __C.AGGREGATOR.GRAPH_WEIGHTED_SUM.DIVIDE_SIZE = True 44 | __C.AGGREGATOR.GRAPH_WEIGHTED_SUM.WEIGHT_ACT = 'sigmoid' 45 | 46 | __C.AGGREGATOR.GRAPH_MULTI_WEIGHTED_SUM = edict() 47 | __C.AGGREGATOR.GRAPH_MULTI_WEIGHTED_SUM.ARGS = ["out_units", "mid_units", "attend_units", "K"] 48 | __C.AGGREGATOR.GRAPH_MULTI_WEIGHTED_SUM.ATTEND_W_DROPOUT = 0.0 49 | __C.AGGREGATOR.GRAPH_MULTI_WEIGHTED_SUM.DIVIDE_SIZE = True 50 | __C.AGGREGATOR.GRAPH_MULTI_WEIGHTED_SUM.WEIGHT_ACT = 'sigmoid' 51 | 52 | __C.AGGREGATOR.MUGGA = edict() 53 | __C.AGGREGATOR.MUGGA.ARGS = ["out_units", "attend_units", "value_units", "K", "context_units", "context_layer_num"] 54 | __C.AGGREGATOR.MUGGA.USE_EDGE = False 55 | __C.AGGREGATOR.MUGGA.RESCALE_INNERPRODUCT = True 56 | __C.AGGREGATOR.MUGGA.ATTEND_W_DROPOUT = 0.5 57 | __C.AGGREGATOR.MUGGA.CONTEXT = edict() # TODO(sxjscience) Add flag to enable/disable dense connection in context network 58 | __C.AGGREGATOR.MUGGA.CONTEXT.USE_SUM_POOL = False ## set to be false and do not change it 59 | __C.AGGREGATOR.MUGGA.CONTEXT.USE_MAX_POOL = True 60 | __C.AGGREGATOR.MUGGA.CONTEXT.USE_AVG_POOL = True 61 | __C.AGGREGATOR.MUGGA.CONTEXT.USE_GATE = False 62 | __C.AGGREGATOR.MUGGA.CONTEXT.USE_SHARPNESS = False 63 | 64 | __C.AGGREGATOR.BIGRAPHPOOL = edict() 65 | __C.AGGREGATOR.BIGRAPHPOOL.ARGS = ["out_units", "mid_units", "num_node_set", "num_edge_set"] 66 | __C.AGGREGATOR.BIGRAPHPOOL.POOL_TYPE = "avg" # Can be "max" and "avg" 67 | __C.AGGREGATOR.BIGRAPHPOOL.ACCUM_TYPE = "sum" # Can be "stack" and "sum" 68 | 69 | 70 | __C.BI_GRAPH=edict() 71 | __C.BI_GRAPH.MODEL = edict() 72 | #__C.BI_GRAPH.MODEL.LOSS_TYPE = "regression" 73 | __C.BI_GRAPH.MODEL.FEA_EMBED_UNITS = 500 74 | #__C.BI_GRAPH.MODEL.FIRST_EMBED_UNITS = 256 75 | __C.BI_GRAPH.MODEL.AGGREGATOR_ARGS_LIST = [["BiGraphPoolAggregator", [None, 500, None, 5]]] 76 | __C.BI_GRAPH.MODEL.OUT_NODE_EMBED = 75 77 | __C.BI_GRAPH.MODEL.DROPOUT_RATE_LIST = [0.7] 78 | __C.BI_GRAPH.MODEL.DENSE_CONNECT = False 79 | 
__C.BI_GRAPH.MODEL.L2_NORMALIZATION = False
__C.BI_GRAPH.MODEL.EVERY_LAYER_L2_NORMALIZATION = False
#__C.HETER_GRAPH.MODEL.PRED_HIDDEN_DIM = 64

__C.BI_GRAPH.MODEL.TRAIN = edict()
__C.BI_GRAPH.MODEL.TRAIN.BATCH_SIZE = 128
__C.BI_GRAPH.MODEL.TRAIN.GRAPH_SAMPLER_ARGS = ["all", 1]  # ["fixed", [50, 20]] # Can be all, fraction, fixed, ...
__C.BI_GRAPH.MODEL.TRAIN.VALID_ITER = 500
__C.BI_GRAPH.MODEL.TRAIN.TEST_ITER = 1
__C.BI_GRAPH.MODEL.TRAIN.MAX_ITER = 100000
__C.BI_GRAPH.MODEL.TRAIN.OPTIMIZER = "adam"
__C.BI_GRAPH.MODEL.TRAIN.LR = 1E-2  # initial learning rate
__C.BI_GRAPH.MODEL.TRAIN.MIN_LR = 1E-5  # Minimum learning rate
__C.BI_GRAPH.MODEL.TRAIN.DECAY_PATIENCE = 5  # Patience of the lr decay. If no better train loss occurs for DECAY_PATIENCE epochs, the lr will be multiplied by LR_DECAY_FACTOR
__C.BI_GRAPH.MODEL.TRAIN.EARLY_STOPPING_PATIENCE = 10  # Patience of early stopping
__C.BI_GRAPH.MODEL.TRAIN.LR_DECAY_FACTOR = 0.5
__C.BI_GRAPH.MODEL.TRAIN.GRAD_CLIP = 5.0
__C.BI_GRAPH.MODEL.TRAIN.WD = 0.0


__C.HETER_GRAPH = edict()
__C.HETER_GRAPH.MODEL = edict()
__C.HETER_GRAPH.MODEL.LOSS_TYPE = "regression"
__C.HETER_GRAPH.MODEL.FEA_EMBED_UNITS = 256
#__C.HETER_GRAPH.MODEL.FIRST_EMBED_UNITS = 256
__C.HETER_GRAPH.MODEL.AGGREGATOR_ARGS_LIST = [["HeterGraphPoolAggregator", [128, 512, 1, 3]],
                                              ["HeterGraphPoolAggregator", [128, 512, 1, 3]]]
__C.HETER_GRAPH.MODEL.OUT_NODE_EMBED = 128
__C.HETER_GRAPH.MODEL.DROPOUT_RATE_LIST = [0.5, 0.5]
__C.HETER_GRAPH.MODEL.DENSE_CONNECT = False
__C.HETER_GRAPH.MODEL.L2_NORMALIZATION = True
__C.HETER_GRAPH.MODEL.EVERY_LAYER_L2_NORMALIZATION = True
#__C.HETER_GRAPH.MODEL.PRED_HIDDEN_DIM = 64

__C.HETER_GRAPH.MODEL.TRAIN = edict()
__C.HETER_GRAPH.MODEL.TRAIN.BATCH_SIZE = 128
__C.HETER_GRAPH.MODEL.TRAIN.GRAPH_SAMPLER_ARGS = ["all", 2]  # ["fixed", [50, 20]] # Can be all, fraction, fixed, ...
__C.HETER_GRAPH.MODEL.TRAIN.VALID_ITER = 625
__C.HETER_GRAPH.MODEL.TRAIN.TEST_ITER = 1
__C.HETER_GRAPH.MODEL.TRAIN.MAX_ITER = 100000
__C.HETER_GRAPH.MODEL.TRAIN.OPTIMIZER = "adam"
__C.HETER_GRAPH.MODEL.TRAIN.LR = 1E-3  # initial learning rate
__C.HETER_GRAPH.MODEL.TRAIN.MIN_LR = 1E-5  # Minimum learning rate
__C.HETER_GRAPH.MODEL.TRAIN.DECAY_PATIENCE = 5  # Patience of the lr decay. If no better train loss occurs for DECAY_PATIENCE epochs, the lr will be multiplied by LR_DECAY_FACTOR
__C.HETER_GRAPH.MODEL.TRAIN.EARLY_STOPPING_PATIENCE = 10  # Patience of early stopping
__C.HETER_GRAPH.MODEL.TRAIN.LR_DECAY_FACTOR = 0.5
__C.HETER_GRAPH.MODEL.TRAIN.GRAD_CLIP = 5.0
__C.HETER_GRAPH.MODEL.TRAIN.WD = 0.0

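
# GRAPH_SAMPLER_ARGS are parsed by mxgraph/sampler.py:
# ["fixed", [25, 10]] -> FixedNeighborSampler(layer_num=2, neighbor_num=[25, 10])
# ["all", 2]          -> AllNeighborSampler(layer_num=2)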

__C.STATIC_GRAPH = edict()
__C.STATIC_GRAPH.MODEL = edict()
__C.STATIC_GRAPH.MODEL.TYP = 'supervised'  ## Only used for logging
__C.STATIC_GRAPH.MODEL.FEATURE_NORMALIZE = False
if __C.DATA_NAME == 'ppi':
    __C.STATIC_GRAPH.MODEL.FIRST_EMBED_UNITS = 64
elif __C.DATA_NAME == 'reddit':
    __C.STATIC_GRAPH.MODEL.FIRST_EMBED_UNITS = 256
__C.STATIC_GRAPH.MODEL.AGGREGATOR_ARGS_LIST = [["MuGGA", [128, 16, 16, 8, 16, 3]],
                                               ["MuGGA", [128, 16, 16, 8, 16, 3]]]
__C.STATIC_GRAPH.MODEL.DROPOUT_RATE_LIST = [0.5, 0.5]  # dropout rate (1 - keep probability)
__C.STATIC_GRAPH.MODEL.DENSE_CONNECT = False
__C.STATIC_GRAPH.MODEL.L2_NORMALIZATION = False
__C.STATIC_GRAPH.MODEL.EVERY_LAYER_L2_NORMALIZATION = False

# The following elements are generally used in unsupervised learning
__C.STATIC_GRAPH.MODEL.EMBED_DIM = 128
__C.STATIC_GRAPH.MODEL.NEG_WEIGHT = 1.0
__C.STATIC_GRAPH.MODEL.TRAIN_NEG_SAMPLE_SCALE = 20
__C.STATIC_GRAPH.MODEL.TRAIN_NEG_SAMPLE_REPLACE = False
__C.STATIC_GRAPH.MODEL.VALID_NEG_SAMPLE_SCALE = 50
__C.STATIC_GRAPH.MODEL.TEST_NEG_SAMPLE_SCALE = 50

__C.STATIC_GRAPH.MODEL.TRAIN = edict()
__C.STATIC_GRAPH.MODEL.TRAIN.BATCH_SIZE = 512
__C.STATIC_GRAPH.MODEL.TRAIN.GRAPH_SAMPLER_ARGS = ["fixed", [25, 10]]  # Can be all, fraction, fixed, ...
__C.STATIC_GRAPH.MODEL.TRAIN.VALID_ITER = 1
__C.STATIC_GRAPH.MODEL.TRAIN.TEST_ITER = 1
__C.STATIC_GRAPH.MODEL.TRAIN.MAX_ITER = 100000
__C.STATIC_GRAPH.MODEL.TRAIN.OPTIMIZER = "adam"
__C.STATIC_GRAPH.MODEL.TRAIN.LR = 1E-3  # initial learning rate
__C.STATIC_GRAPH.MODEL.TRAIN.MIN_LR = 1E-3  # Minimum learning rate
__C.STATIC_GRAPH.MODEL.TRAIN.DECAY_PATIENCE = 15  # Patience of the lr decay. If no better train loss occurs for DECAY_PATIENCE epochs, the lr will be multiplied by LR_DECAY_FACTOR
__C.STATIC_GRAPH.MODEL.TRAIN.EARLY_STOPPING_PATIENCE = 30  # Patience of early stopping
__C.STATIC_GRAPH.MODEL.TRAIN.LR_DECAY_FACTOR = 0.5
__C.STATIC_GRAPH.MODEL.TRAIN.GRAD_CLIP = 1.0
__C.STATIC_GRAPH.MODEL.TRAIN.WD = 0.0

__C.STATIC_GRAPH.MODEL.TEST = edict()
__C.STATIC_GRAPH.MODEL.TEST.BATCH_SIZE = 512
__C.STATIC_GRAPH.MODEL.TEST.SAMPLE_NUM = 5

__C.SPATIOTEMPORAL_GRAPH = edict()
__C.SPATIOTEMPORAL_GRAPH.IN_LENGTH = 12
__C.SPATIOTEMPORAL_GRAPH.OUT_LENGTH = 12
__C.SPATIOTEMPORAL_GRAPH.USE_COORDINATES = True
__C.SPATIOTEMPORAL_GRAPH.MODEL = edict()
__C.SPATIOTEMPORAL_GRAPH.MODEL.RNN_TYPE = "RNN"
__C.SPATIOTEMPORAL_GRAPH.MODEL.AGGREGATOR_ARGS_LIST = [["MuGGA", [64, 8, 16, 4, 32, 1]],
                                                       ["MuGGA", [64, 8, 16, 4, 32, 1]]]
__C.SPATIOTEMPORAL_GRAPH.MODEL.AGGREGATION_TYPE = "all"
__C.SPATIOTEMPORAL_GRAPH.MODEL.ADJ_PREPROCESS = 'undirected'
__C.SPATIOTEMPORAL_GRAPH.MODEL.DROPOUT_RATE = 0.0
__C.SPATIOTEMPORAL_GRAPH.MODEL.DIFFUSSION_STEP = 1
__C.SPATIOTEMPORAL_GRAPH.MODEL.SHARPNESS_LAMBDA = 0.0
__C.SPATIOTEMPORAL_GRAPH.MODEL.DIVERSITY_LAMBDA = 0.0
__C.SPATIOTEMPORAL_GRAPH.MODEL.USE_EDGE = False

__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN = edict()
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.BATCH_SIZE = 64
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.MAX_EPOCH = 100
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.OPTIMIZER = "adam"
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.LR = 1E-3  # initial learning rate
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.MIN_LR = 1E-5  # Minimum learning rate
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.SCHEDULED_SAMPLING = edict()
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.SCHEDULED_SAMPLING.TAU = 3000  # tau / (tau + exp(iter / tau))
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.INITIAL_EPOCHS = 20
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.DECAY_PATIENCE = 10  # Patience of the lr decay. Decay the learning rate every DECAY_PATIENCE epochs
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.EARLY_STOPPING_PATIENCE = 5  # Patience of early stopping
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.LR_DECAY_FACTOR = 0.1
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.GRAD_CLIP = 5.0
__C.SPATIOTEMPORAL_GRAPH.MODEL.TRAIN.WD = 0.0



def _merge_two_config(user_cfg, default_cfg):
    """ Merge user's config into the default config dictionary, clobbering the
        options in default_cfg whenever they are also specified in user_cfg.
        Need to ensure the types of the two vals under the same key are the same.
        Do recursive merge when encountering a hierarchical dictionary.
    """
    if type(user_cfg) is not edict:
        return
    for key, val in user_cfg.items():
        # Since user_cfg is a sub-file of default_cfg
        if key not in default_cfg:
            raise KeyError('{} is not a valid config key'.format(key))

        if (type(default_cfg[key]) is not type(val) and
                default_cfg[key] is not None):
            if isinstance(default_cfg[key], np.ndarray):
                val = np.array(val, dtype=default_cfg[key].dtype)
            elif isinstance(default_cfg[key], (int, float)) and isinstance(val, (int, float)):
                pass
            else:
                raise ValueError(
                    'Type mismatch ({} vs. {}) '
                    'for config key: {}'.format(type(default_cfg[key]),
                                                type(val), key))
        # Recursive merge config
        if type(val) is edict:
            try:
                _merge_two_config(user_cfg[key], default_cfg[key])
            except:
                print('Error under config key: {}'.format(key))
                raise
        else:
            default_cfg[key] = val
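

# Illustrative override (hypothetical my_cfg.yml); only the listed keys are
# replaced, everything else keeps the defaults above:
#
#     MX_SEED: 1
#     STATIC_GRAPH:
#       MODEL:
#         TRAIN:
#           BATCH_SIZE: 256
#
# Loading it with cfg_from_file('my_cfg.yml') merges it via _merge_two_config.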

def cfg_from_file(file_name, target=__C):
    """ Load a config file and merge it into the default options.
    """
    with open(file_name, 'r') as f:
        print('Loading YAML config file from %s' % file_name)
        yaml_cfg = edict(yaml.safe_load(f))  # the cfg files are written with yaml.SafeDumper, see ordered_dump

    _merge_two_config(yaml_cfg, target)


def ordered_dump(data=__C, stream=None, Dumper=yaml.SafeDumper, **kwds):
    class OrderedDumper(Dumper):
        pass

    def _dict_representer(dumper, data):
        return dumper.represent_mapping(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
            data.items(), flow_style=False)

    def _ndarray_representer(dumper, data):
        return dumper.represent_list(data.tolist())

    OrderedDumper.add_representer(OrderedDict, _dict_representer)
    OrderedDumper.add_representer(edict, _dict_representer)
    OrderedDumper.add_representer(np.ndarray, _ndarray_representer)
    return yaml.dump(data, stream, OrderedDumper, **kwds)


def save_cfg_file(file_path, source=__C):
    source = source.copy()
    masked_keys = ['DATASET_PATH', 'ROOT_DIR']
    for key in masked_keys:
        if key in source:
            del source[key]
            delattr(source, key)
    with open(file_path, 'w') as f:
        logging.info("Save YAML config file to %s" % file_path)
        ordered_dump(source, f, yaml.SafeDumper, default_flow_style=None)


def save_cfg_dir(dir_path, source=__C):
    cfg_count = 0
    file_path = os.path.join(dir_path, 'cfg%d.yml' % cfg_count)
    while os.path.exists(file_path):
        cfg_count += 1
        file_path = os.path.join(dir_path, 'cfg%d.yml' % cfg_count)
    save_cfg_file(file_path, source)
    return cfg_count


def load_latest_cfg(dir_path, target=__C):
    import re
    cfg_count = None
    source_cfg_path = None
    for fname in os.listdir(dir_path):
        ret = re.search(r'cfg(\d+)\.yml', fname)
        if ret is not None:
            if cfg_count is None or (int(ret.group(1)) > cfg_count):
                cfg_count = int(ret.group(1))
                source_cfg_path = os.path.join(dir_path, ret.group(0))
    cfg_from_file(file_name=source_cfg_path, target=target)


# save_f_name = os.path.join("..", "experiments", "heterogeneous_graph", "baselines", "cfg_template","ml_100k.yml")
# save_cfg_file(save_f_name)
--------------------------------------------------------------------------------
/mxgraph/sampler.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy.sparse as ss
import logging
from mxgraph import _graph_sampler


### numpy version of negative sampling TODO out-of-date
class UnsupLossSampler(object):
    def __init__(self, num_neg_sample, **kwargs):
        """
        Parameters
        ----------
        num_neg_sample : int
        kwargs
        """
        self.num_neg_sample = num_neg_sample

    def run(self, orig_end_points, orig_indptr, sample_target_indices):
        """
        For training:
orig_end_points and orig_indptr refer to the training+validation subgraph 24 | sample_target_indices are the validation nodes 25 | Parameters 26 | ---------- 27 | orig_end_points : 28 | Shape (num_edges,) 29 | orig_indptr : 30 | Shape (num_nodes+1, ) 31 | sample_target_indices : 32 | Shape (#sample_nodes,) 33 | 34 | Returns 35 | ------- 36 | 37 | """ 38 | self.orig_node_size = orig_indptr.size-1 39 | 40 | ### Fetch positive edges 41 | values = np.ones(shape=orig_end_points.shape, dtype=np.float32) 42 | pos_adj = ss.csr_matrix((values, orig_end_points, orig_indptr), 43 | shape=(self.orig_node_size, self.orig_node_size)) 44 | 45 | ### Negative Sampling 46 | p_unnorm = pos_adj 47 | p_unnorm.data = values * -1. 48 | p_unnorm = p_unnorm.toarray() + 1. ## negative edge: 1.. ; positive edge: 0.0 to compute the sampling probability 49 | np.fill_diagonal(p_unnorm, .0) 50 | p = p_unnorm / p_unnorm.sum(axis=1, keepdims=True) 51 | 52 | row = np.asarray([np.ones(self.num_neg_sample, dtype=np.int32)*node_id for node_id in sample_target_indices]) 53 | column = np.asarray([np.random.choice(self.orig_node_size, 54 | size=self.num_neg_sample, 55 | replace=False, 56 | p=p[node_id]).tolist() 57 | for node_id in sample_target_indices]) 58 | neg_adj = ss.coo_matrix((np.ones(self.num_neg_sample * sample_target_indices.size) * -1.0, 59 | (row.reshape(-1), column.reshape(-1,))), 60 | shape=(self.orig_node_size, self.orig_node_size)).tocsr() 61 | ### Sum pos_adj and neg_adj 62 | all_adj = pos_adj + neg_adj 63 | sampled_adj = all_adj[sample_target_indices] 64 | end_points, indptr, edge_labels = sampled_adj.indices, sampled_adj.indptr, sampled_adj.data 65 | return end_points, indptr, edge_labels 66 | 67 | 68 | class LinkPredEdgeSampler(object): 69 | def __init__(self, neg_sample_scale, replace=False): 70 | """ 71 | 72 | Parameters 73 | ---------- 74 | neg_sample_scale: float 75 | For every node, the number of negative samples will be: 76 | 77 | .. Code:: 78 | 79 | expected_neg_num = ceil(pos_num * neg_sample_scale) 80 | expected_neg_num = min(expected_neg_num, total_num - pos_num) 81 | 82 | If this value is set to be negative, no sampling will be performed and 83 | all the end_points will be returned. 
84 | replace : bool 85 | Whether to sample the negative edges with replacement 86 | """ 87 | self._neg_sample_scale = neg_sample_scale 88 | self._replace = replace 89 | 90 | def sample_by_indices(self, G, chosen_node_indices): 91 | """Sample the G with the chosen_node_indices 92 | 93 | Parameters 94 | ---------- 95 | G : SimpleGraph 96 | chosen_node_indices : np.ndarray 97 | 98 | Returns 99 | ------- 100 | end_points : np.ndarray 101 | indptr : np.ndarray 102 | edge_label : np.ndarray 103 | edge_count : np.ndarray 104 | """ 105 | end_points, indptr, edge_label, edge_count = \ 106 | _graph_sampler.uniform_neg_sampling(G.end_points.astype(np.int32), 107 | G.ind_ptr.astype(np.int32), 108 | chosen_node_indices.astype(np.int32), 109 | float(self._neg_sample_scale), 110 | int(self._replace)) 111 | return end_points, indptr, edge_label, edge_count 112 | 113 | def sample_by_id(self, G, chosen_node_ids): 114 | """Sample the G with the chosen_node_ids 115 | 116 | Parameters 117 | ---------- 118 | G : SimpleGraph 119 | chosen_node_ids : np.ndarray 120 | 121 | Returns 122 | ------- 123 | end_points : np.ndarray 124 | indptr : np.ndarray 125 | edge_label : np.ndarray 126 | edge_count : np.ndarray 127 | """ 128 | return self.sample_by_indices(G, G.reverse_map(chosen_node_ids)) 129 | 130 | 131 | class BaseHierarchicalNodeSampler(object): 132 | def __init__(self, layer_num): 133 | self._layer_num = layer_num 134 | 135 | @property 136 | def layer_num(self): 137 | return self._layer_num 138 | 139 | def layer_sample_and_merge(self, G, sel_indices, depth): 140 | """ 141 | 142 | Parameters 143 | ---------- 144 | G : SimpleGraph 145 | sel_indices : np.ndarray 146 | depth : int 147 | 148 | Returns 149 | ------- 150 | end_points : np.ndarray 151 | indptr : np.ndarray 152 | merged_node_ids : np.ndarray 153 | indices_in_merged : np.ndarray 154 | """ 155 | raise NotImplementedError 156 | 157 | def sample_by_indices(self, G, base_node_indices): 158 | """Note that the order of our samples should be from the lowest layer to the highest layer 159 | 160 | Parameters 161 | ---------- 162 | G : SimpleGraph 163 | base_node_indices : np.ndarray 164 | 165 | Returns 166 | ------- 167 | indices_in_merged_l : list 168 | The first element will be the indices of the 0th layer in the original graph 169 | end_points_l : list 170 | indptr_l : list 171 | node_ids_l : list 172 | The original ids of the indices 173 | """ 174 | base_node_indices = base_node_indices.astype(np.int32) 175 | indices_in_merged_l = [None for _ in range(self._layer_num + 1)] 176 | end_points_l = [None for _ in range(self._layer_num)] 177 | indptr_l = [None for _ in range(self._layer_num)] 178 | node_ids_l = [None for _ in range(self._layer_num + 1)] 179 | node_ids_l[self._layer_num] = G.node_ids[base_node_indices] 180 | for i in range(self._layer_num - 1, -1, -1): 181 | end_points, indptr, merged_node_ids, indices_in_merged =\ 182 | self.layer_sample_and_merge(G, base_node_indices, self._layer_num - 1 - i) 183 | base_node_indices = G.reverse_map(merged_node_ids) 184 | indices_in_merged_l[i + 1] = indices_in_merged 185 | node_ids_l[i] = merged_node_ids 186 | end_points_l[i] = end_points 187 | indptr_l[i] = indptr 188 | indices_in_merged_l[0] = base_node_indices 189 | return indices_in_merged_l, end_points_l, indptr_l, node_ids_l 190 | 191 | def sample_by_indices_with_edge_weight(self, G, base_node_indices): 192 | """Note that the order of our samples should be from the lowest layer to the highest layer 193 | 194 | Parameters 195 | ---------- 196 | G : 
SimpleGraph 197 | base_node_indices : np.ndarray 198 | 199 | Returns 200 | ------- 201 | indices_in_merged_l : list 202 | The first element will be the indices of the 0th layer in the original graph 203 | end_points_l : list 204 | indptr_l : list 205 | node_ids_l : list 206 | The original ids of the indices 207 | """ 208 | base_node_indices = base_node_indices.astype(np.int32) 209 | indices_in_merged_l = [None for _ in range(self._layer_num + 1)] 210 | end_points_l = [None for _ in range(self._layer_num)] 211 | indptr_l = [None for _ in range(self._layer_num)] 212 | end_points_edge_weight_l = [None for _ in range(self._layer_num)] 213 | node_ids_l = [None for _ in range(self._layer_num + 1)] 214 | node_ids_l[self._layer_num] = G.node_ids[base_node_indices] 215 | for i in range(self._layer_num - 1, -1, -1): 216 | end_points, indptr, merged_node_ids, indices_in_merged =\ 217 | self.layer_sample_and_merge(G, base_node_indices, self._layer_num - 1 - i) 218 | base_node_indices = G.reverse_map(merged_node_ids) 219 | indices_in_merged_l[i + 1] = indices_in_merged 220 | node_ids_l[i] = merged_node_ids 221 | end_points_l[i] = end_points 222 | indptr_l[i] = indptr 223 | # print("previous base_node_indices", G.reverse_map(node_ids_l[i+1]).shape, G.reverse_map(node_ids_l[i+1])) 224 | # print("indices_in_merged", indices_in_merged.shape, indices_in_merged) 225 | # print("base_node_indices", base_node_indices.shape, base_node_indices) 226 | sampled_sub_adj = G.adj.submat(row_indices=G.reverse_map(node_ids_l[i+1]), col_indices=G.reverse_map(node_ids_l[i])) 227 | # print("**********************************") 228 | # print("sampled_sub_adj.ind_ptr", sampled_sub_adj.ind_ptr.shape, sampled_sub_adj.ind_ptr) 229 | # print("sampled indptr", indptr.shape, indptr) 230 | # print("**********************************") 231 | # print("sampled_sub_adj.end_points", sampled_sub_adj.end_points.shape, sampled_sub_adj.end_points) 232 | # print("sampled end_points", end_points.shape, end_points) 233 | # print("**********************************") 234 | # print("sampled_sub_adj.values", sampled_sub_adj.values.shape, sampled_sub_adj.values) 235 | end_points_edge_weight_l[i] =sampled_sub_adj.values 236 | #print("=============================================") 237 | 238 | indices_in_merged_l[0] = base_node_indices 239 | return indices_in_merged_l, end_points_l, indptr_l, node_ids_l, end_points_edge_weight_l 240 | 241 | 242 | # def sample_by_id(self, G, base_node_ids): 243 | # return self.sample_by_indices(G, base_node_ids) 244 | 245 | 246 | class FractionNeighborSampler(BaseHierarchicalNodeSampler): 247 | def __init__(self, layer_num, neighbor_fraction=None, replace=False): 248 | super(FractionNeighborSampler, self).__init__(layer_num=layer_num) 249 | self._neighbor_fraction = neighbor_fraction 250 | self._replace = replace 251 | assert len(self._neighbor_fraction) == layer_num 252 | for ele in self._neighbor_fraction: 253 | assert ele > 0 254 | 255 | def layer_sample_and_merge(self, G, sel_indices, depth): 256 | return _graph_sampler.random_sel_neighbor_and_merge( 257 | G.end_points, G.ind_ptr, G.node_ids, sel_indices.astype(np.int32), 258 | -1, np.float32(self._neighbor_fraction[depth]), 0, int(self._replace)) 259 | 260 | 261 | class FixedNeighborSampler(BaseHierarchicalNodeSampler): 262 | def __init__(self, layer_num, neighbor_num=None, replace=False): 263 | super(FixedNeighborSampler, self).__init__(layer_num=layer_num) 264 | self._neighbor_num = neighbor_num 265 | self._replace = replace 266 | assert len(self._neighbor_num) 
== self._layer_num 267 | for ele in self._neighbor_num: 268 | assert ele > 0 269 | 270 | def layer_sample_and_merge(self, G, sel_indices, depth): 271 | return _graph_sampler.random_sel_neighbor_and_merge( 272 | G.end_points, G.ind_ptr, G.node_ids, sel_indices.astype(np.int32), 273 | self._neighbor_num[depth], -1.0, 0, int(self._replace)) 274 | 275 | 276 | 277 | class AllNeighborSampler(BaseHierarchicalNodeSampler): 278 | def _npy_layer_sample_and_merge(self, G, base_node_indices): 279 | end_points = [] 280 | indptr = [] 281 | merged_node_ids = [] 282 | indices_in_merged = [] 283 | indptr.append(0) 284 | for ind in base_node_indices: 285 | end_points.extend(G.end_points[G.ind_ptr[ind]:G.ind_ptr[ind + 1]].tolist()) 286 | indptr.append(indptr[-1] + G.ind_ptr[ind + 1] - G.ind_ptr[ind]) 287 | counter = 0 288 | ind_dict = dict() 289 | for i, val in enumerate(end_points): 290 | if val not in ind_dict: 291 | ind_dict[val] = counter 292 | end_points[i] = counter 293 | merged_node_ids.append(G.node_ids[val]) 294 | counter += 1 295 | else: 296 | end_points[i] = ind_dict[val] 297 | for val in base_node_indices: 298 | if val in ind_dict: 299 | indices_in_merged.append(ind_dict[val]) 300 | else: 301 | indices_in_merged.append(counter) 302 | merged_node_ids.append(G.node_ids[val]) 303 | counter += 1 304 | end_points = np.array(end_points, dtype=np.int32) 305 | indptr = np.array(indptr, dtype=np.int32) 306 | merged_node_ids = np.array(merged_node_ids, dtype=np.int32) 307 | indices_in_merged = np.array(indices_in_merged) 308 | return end_points, indptr, merged_node_ids, indices_in_merged 309 | 310 | def layer_sample_and_merge(self, G, sel_indices, depth): 311 | return _graph_sampler.random_sel_neighbor_and_merge( 312 | G.end_points, G.ind_ptr, G.node_ids, sel_indices.astype(np.int32), 313 | -1, -1.0, 1, 1) 314 | 315 | 316 | class NoMergeSampler(FixedNeighborSampler): 317 | def layer_sample_and_merge(self, G, sel_indices, depth): 318 | raise NotImplementedError 319 | 320 | 321 | 322 | def parse_hierarchy_sampler_from_desc(desc): 323 | assert len(desc) == 2 324 | name = desc[0] 325 | args = desc[1] 326 | if name.lower() == "fraction".lower(): 327 | logging.info("FractionNeighborSampler: layer_num=%d, fraction=%s" % (len(args), str(args))) 328 | return FractionNeighborSampler(layer_num=len(args), neighbor_fraction=args) 329 | elif name.lower() == "fixed".lower(): 330 | logging.info("FixedNeighborSampler: layer_num=%d, sample_num=%s" % (len(args), str(args))) 331 | return FixedNeighborSampler(layer_num=len(args), neighbor_num=args) 332 | elif name.lower() == "all".lower(): 333 | #assert len(args) == 1 334 | logging.info("AllNeighborSampler: layer_num=%d" % args) 335 | return AllNeighborSampler(layer_num=int(args)) 336 | else: 337 | raise NotImplementedError("name={name} is not supported!".format(name=name)) 338 | -------------------------------------------------------------------------------- /experiments/static_graph/sup_train_sample.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import mxnet as mx 4 | from mxgraph.graph import * 5 | from mxgraph.iterators import StaticGraphIterator 6 | from mxgraph.layers import * 7 | from mxgraph.utils import * 8 | from mxgraph.helpers.email_sender import send_msg 9 | from mxgraph.config import cfg, cfg_from_file, save_cfg_dir, ordered_dump 10 | from mxgraph.helpers.metric_logger import MetricLogger 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description='Run the 
-------------------------------------------------------------------------------- /experiments/static_graph/sup_train_sample.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import mxnet as mx 4 | from mxgraph.graph import * 5 | from mxgraph.iterators import StaticGraphIterator 6 | from mxgraph.layers import * 7 | from mxgraph.utils import * 8 | from mxgraph.helpers.email_sender import send_msg 9 | from mxgraph.config import cfg, cfg_from_file, save_cfg_dir, ordered_dump 10 | from mxgraph.helpers.metric_logger import MetricLogger 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description='Run the supervised training experiments with sampling.') 14 | parser.add_argument('--cfg', dest='cfg_file', help='Optional configuration file', default=None, type=str) 15 | parser.add_argument('--ctx', dest='ctx', default='gpu', 16 | help='Running Context. E.g `--ctx gpu` or `--ctx gpu0,gpu1` or `--ctx cpu`', type=str) 17 | parser.add_argument('--test', dest='test', help="Whether to run in the test mode", action="store_true") 18 | parser.add_argument('--silent', dest='silent', action='store_true') 19 | parser.add_argument('--output_inner_result', dest='output_inner_result', action='store_true') 20 | parser.add_argument('--save_epoch_interval', dest='save_epoch_interval', 21 | help="Epoch interval to output the inner result", default=20, type=int) 22 | parser.add_argument('--load_dir', dest='load_dir', help="The directory to load the pretrained model", 23 | default=None, type=str) 24 | parser.add_argument('--load_iter', dest='load_iter', help="The iteration to load", default=None, type=int) 25 | parser.add_argument('--save_dir', help='The saving directory', default=None, type=str) 26 | parser.add_argument('--emails', dest='emails', type=str, default="", help='Email addresses') 27 | args = parser.parse_args() 28 | args.ctx = parse_ctx(args.ctx) 29 | 30 | if args.cfg_file is not None: 31 | cfg_from_file(args.cfg_file, target=cfg) 32 | ### configure save_dir to save all the info 33 | if args.save_dir is None: 34 | if args.cfg_file is None: 35 | raise ValueError("Must set --cfg if --save_dir is not set") 36 | args.save_dir = os.path.splitext(args.cfg_file)[0] 37 | if not os.path.isdir(args.save_dir): 38 | os.makedirs(args.save_dir) 39 | return args 40 | 41 | 42 | def eval_classification(net, loss_function, data_iterator, num_class, mode): 43 | """Evaluate the classification accuracy 44 | Note: uses the module-level `args` for the inner-result save paths. 45 | Parameters 46 | ---------- 47 | net : GraphMultiLayerHierarchicalNodes 48 | loss_function : mx.gluon.loss.Loss 49 | data_iterator : StaticGraphIterator 50 | num_class : int 51 | mode : str 52 | "valid" or "test" 53 | 54 | Returns 55 | ------- 56 | avg_loss : float 57 | f1 : float 58 | accuracy : float 59 | """ 60 | assert mode in ['valid', 'test'] 61 | 62 | labels_all = None 63 | node_ids_all = None 64 | preds_all = None 65 | total_loss = 0.0 66 | instance_num = 0 67 | data_iterator.begin_epoch(mode=mode) 68 | batch_id = 0  # running index for the saved inner-result files 69 | while not data_iterator.epoch_finished: 70 | layer0_features_nd, end_points_l, indptr_l, indices_in_merged_l, labels_nd, node_ids_l = \ 71 | data_iterator.sample() 72 | if net._output_inner_result: 73 | logits, gate_l, sharpness_l, attend_weights_wo_gate_l = \ 74 | net(layer0_features_nd, end_points_l, indptr_l, indices_in_merged_l) 75 | np.save(os.path.join(args.save_dir, 76 | 'inner_results%d' % args.save_id, 77 | mode+'_gate%d_1.npy' % batch_id), gate_l[0].asnumpy()) 78 | # print("gate_1", gate_l[0].asnumpy().shape, gate_l[0].asnumpy()) 79 | np.save(os.path.join(args.save_dir, 80 | 'inner_results%d' % args.save_id, 81 | mode+'_gate%d_2.npy' % batch_id), gate_l[1].asnumpy()) 82 | # print("gate_2", gate_l[1].asnumpy().shape, gate_l[1].asnumpy()) 83 | np.save(os.path.join(args.save_dir, 84 | 'inner_results%d' % args.save_id, 85 | mode+'_node_id%d_1.npy' % batch_id), node_ids_l[1]) 86 | # print("node_id_1", node_ids_l[1].shape, node_ids_l[1]) 87 | np.save(os.path.join(args.save_dir, 88 | 'inner_results%d' % args.save_id, 89 | mode+'_node_id%d_2.npy' % batch_id), node_ids_l[2]) 90 | batch_id += 1 91 | else: 92 | logits = net(layer0_features_nd, end_points_l, indptr_l, indices_in_merged_l) 93 | 94 | total_loss +=
nd.sum(loss_function(logits, labels_nd)).asscalar() 95 | instance_num += labels_nd.shape[0] 96 | if cfg.DATA_NAME == 'ppi': 97 | iter_preds = (logits > 0) 98 | iter_labels = labels_nd 99 | else: 100 | iter_preds = nd.argmax(logits, axis=1) 101 | iter_labels = labels_nd.reshape((-1,)) 102 | # print(list(zip(iter_preds.tolist(), iter_labels.ravel().tolist()))) 103 | if preds_all is None: 104 | preds_all = iter_preds 105 | labels_all = iter_labels 106 | node_ids_all = node_ids_l[-1] 107 | else: 108 | preds_all = nd.concatenate([preds_all, iter_preds], axis=0) 109 | labels_all = nd.concatenate([labels_all, iter_labels], axis=0) 110 | ### node_ids is numpy array 111 | node_ids_all = np.concatenate([node_ids_all, node_ids_l[-1]], axis=0) 112 | avg_loss = total_loss / instance_num 113 | 114 | if cfg.DATA_NAME == 'ppi': 115 | num_class = 2 116 | f1 = nd_f1(pred=preds_all, label=labels_all, num_class=num_class, average="micro") 117 | acc = nd_acc(pred=preds_all, label=labels_all) 118 | 119 | return avg_loss, f1, acc 120 | 121 | 122 | def build(args): 123 | ctx = args.ctx[0] 124 | local_cfg = cfg.STATIC_GRAPH 125 | ### initialize data_iterator 126 | data_iterator = StaticGraphIterator(hierarchy_sampler_desc=local_cfg.MODEL.TRAIN.GRAPH_SAMPLER_ARGS, 127 | ctx=ctx, 128 | supervised=True, 129 | batch_node_num=local_cfg.MODEL.TRAIN.BATCH_SIZE, 130 | normalize_feature=local_cfg.MODEL.FEATURE_NORMALIZE, 131 | batch_sample_method="uniform") 132 | data_iterator.summary() 133 | 134 | ### build net 135 | net = GraphMultiLayerHierarchicalNodes(out_units=data_iterator.num_class, 136 | aggregator_args_list=local_cfg.MODEL.AGGREGATOR_ARGS_LIST, 137 | dropout_rate_list=local_cfg.MODEL.DROPOUT_RATE_LIST, 138 | dense_connect=local_cfg.MODEL.DENSE_CONNECT, 139 | l2_normalization=local_cfg.MODEL.L2_NORMALIZATION, 140 | first_embed_units=local_cfg.MODEL.FIRST_EMBED_UNITS, 141 | output_inner_result=args.output_inner_result, 142 | prefix='net_') 143 | net.hybridize() 144 | 145 | ### define loss_function 146 | if cfg.DATA_NAME == 'ppi': 147 | loss_function = gluon.loss.LogisticLoss(label_format='binary') 148 | else: 149 | loss_function = gluon.loss.SoftmaxCrossEntropyLoss(from_logits=False) 150 | loss_function.hybridize() 151 | 152 | return net, loss_function, data_iterator 153 | 154 | 155 | def train(args, net, loss_function, data_iterator): 156 | """Train the model 157 | """ 158 | ctx = args.ctx[0] 159 | local_cfg = cfg.STATIC_GRAPH 160 | 161 | net.initialize(init=mx.init.Xavier(magnitude=3), ctx=ctx) 162 | trainer = gluon.Trainer(net.collect_params(), 163 | local_cfg.MODEL.TRAIN.OPTIMIZER, 164 | {'learning_rate': local_cfg.MODEL.TRAIN.LR, 165 | 'wd': local_cfg.MODEL.TRAIN.WD}) 166 | 167 | train_loss_logger = MetricLogger(["iter", "loss"], ["%d", "%.4f"], 168 | os.path.join(args.save_dir, 'train_loss%d.csv' % args.save_id)) 169 | valid_loss_logger = MetricLogger(["iter", "loss", "f1", "acc", "is_best"], 170 | ["%d", "%.4f", "%.4f", "%.4f", "%d"], 171 | os.path.join(args.save_dir, 'valid_loss%d.csv' % args.save_id)) 172 | test_loss_logger = MetricLogger(["iter", "loss", "f1", "acc"], ["%d", "%.4f", "%.4f", "%.4f"], 173 | os.path.join(args.save_dir, 'test_loss%d.csv' % args.save_id)) 174 | 175 | best_valid_f1 = 0 176 | best_valid_iter_info = [] 177 | best_test_iter_info = [] 178 | no_better_valid = 0 179 | iter_id = 1 180 | epoch = 1 181 | train_moving_avg_loss = 0.0 182 | data_iterator.begin_epoch('train') 183 | for iter in range(1, local_cfg.MODEL.TRAIN.MAX_ITER): 184 | if data_iterator.epoch_finished: 185 | 
print("Epoch %d finished! It has %d iterations." % (epoch, iter_id)) 186 | data_iterator.begin_epoch('train') 187 | iter_id = 1 188 | epoch += 1 189 | else: 190 | iter_id += 1 191 | 192 | layer0_features_nd, end_points_l, indptr_l, indices_in_merged_l, labels_nd, node_ids_l =\ 193 | data_iterator.sample() 194 | # print("layer0_features_nd", layer0_features_nd.shape, "\n", layer0_features_nd) 195 | # print("end_points_l", len(end_points_l), end_points_l) 196 | # print("indptr_l", len(indptr_l), indptr_l) 197 | # print("indices_in_merged_l", len(indices_in_merged_l), indices_in_merged_l) 198 | # print("labels_nd", labels_nd.shape, "\n", labels_nd) 199 | #print("node_id", node_ids_l[0].shape, node_ids_l[1].shape,node_ids_l[2].shape) 200 | with mx.autograd.record(): 201 | if net._output_inner_result: 202 | logits, gate_l, sharpness_l, attend_weights_wo_gate_l =\ 203 | net(layer0_features_nd, end_points_l, indptr_l, indices_in_merged_l) 204 | # print("gate", len(gate_l), gate_l[0].shape, gate_l[1].shape) 205 | # print(gate_l[0]) 206 | # print(gate_l[1]) 207 | # if epoch % args.save_epoch_interval == 1 or epoch < args.save_epoch_interval + 1: 208 | # #temp_dict = dict([('gate%d' % i, gate.asnumpy()) 209 | # # for i, gate in enumerate(gate_l)] ) 210 | # np.save(os.path.join(args.save_dir, 211 | # 'inner_results%d' % args.save_id, 212 | # 'train_gate%d_1.npy' % epoch), gate_l[0].asnumpy()) 213 | # #print("gate_1", gate_l[0].asnumpy().shape, gate_l[0].asnumpy()) 214 | # np.save(os.path.join(args.save_dir, 215 | # 'inner_results%d' % args.save_id, 216 | # 'train_gate%d_2.npy' % epoch), gate_l[1].asnumpy()) 217 | # #print("gate_2", gate_l[1].asnumpy().shape, gate_l[1].asnumpy()) 218 | # np.save(os.path.join(args.save_dir, 219 | # 'inner_results%d' % args.save_id, 220 | # 'train_node_id%d_1.npy' % epoch), node_ids_l[1]) 221 | # #print("node_id_1", node_ids_l[1].shape, node_ids_l[1]) 222 | # np.save(os.path.join(args.save_dir, 223 | # 'inner_results%d' % args.save_id, 224 | # 'train_node_id%d_2.npy' % epoch), node_ids_l[2]) 225 | #print("node_id_2", node_ids_l[2].shape, node_ids_l[2]) 226 | # temp_dict = dict([('gate%d' % i, gate.asnumpy()) 227 | # for i, gate in enumerate(gate_l)] + 228 | # [('attend_weights_wo_gate%d' % i, ele) 229 | # for i, ele in enumerate(attend_weights_wo_gate_l)]) 230 | # np.savez(os.path.join(args.save_dir, 231 | # 'inner_results%d' % args.save_id, 232 | # 'gate_attweight%d.npz' % epoch), **temp_dict) 233 | else: 234 | logits = net(layer0_features_nd, end_points_l, indptr_l, indices_in_merged_l) 235 | loss = loss_function(logits, labels_nd) 236 | loss = nd.mean(loss) 237 | loss.backward() 238 | if iter == 1: 239 | logging.info("Total Param Number: %d" % gluon_total_param_num(net)) 240 | gluon_log_net_info(net, save_path=os.path.join(args.save_dir, 'net_info%d.txt' % args.save_id)) 241 | ### norm clipping 242 | if local_cfg.MODEL.TRAIN.GRAD_CLIP <= 0: 243 | gnorm = get_global_norm([v.grad() for v in net.collect_params().values()]) 244 | else: 245 | gnorm = gluon.utils.clip_global_norm([v.grad() for v in net.collect_params().values()], 246 | max_norm=local_cfg.MODEL.TRAIN.GRAD_CLIP) 247 | trainer.step(batch_size=1) 248 | iter_train_loss = loss.asscalar() 249 | train_moving_avg_loss += iter_train_loss 250 | logging.info('[iter=%d]: loss=%.4f, gnorm=%g' % (iter, iter_train_loss, gnorm)) 251 | train_loss_logger.log(iter=iter, loss=iter_train_loss) 252 | 253 | if iter % local_cfg.MODEL.TRAIN.VALID_ITER == 0: 254 | valid_loss, valid_f1, valid_accuracy = \ 255 | 
eval_classification(net=net, 256 | loss_function=loss_function, 257 | data_iterator=data_iterator, 258 | num_class=data_iterator.num_class, 259 | mode="valid") 260 | logging.info("Iter %d, Epoch %d,: train_moving_loss=%.4f, valid loss=%.4f, f1=%.4f, accuracy=%.4f" 261 | % (iter, epoch, train_moving_avg_loss / local_cfg.MODEL.TRAIN.VALID_ITER, 262 | valid_loss, valid_f1, valid_accuracy)) 263 | 264 | train_moving_avg_loss = 0.0 265 | if valid_f1 > best_valid_f1: 266 | logging.info("======================> Best Iter") 267 | is_best = True 268 | best_valid_f1 = valid_f1 269 | best_iter = iter 270 | best_valid_iter_info = [best_iter, valid_loss, valid_f1, valid_accuracy] 271 | no_better_valid = 0 272 | net.save_params( 273 | filename=os.path.join(args.save_dir, 'best_valid%d.params' % args.save_id)) 274 | # Calculate the test loss 275 | test_loss, test_f1, test_accuracy = \ 276 | eval_classification(net=net, 277 | loss_function=loss_function, 278 | data_iterator=data_iterator, 279 | num_class=data_iterator.num_class, 280 | mode="test") 281 | test_loss_logger.log(iter=iter, loss=test_loss, f1=test_f1, acc=test_accuracy) 282 | best_test_iter_info = [best_iter, test_loss, test_f1, test_accuracy] 283 | logging.info("Iter %d, Epoch %d: test loss=%.4f, f1=%.4f, accuracy=%.4f" % 284 | (iter, epoch, test_loss, test_f1, test_accuracy)) 285 | else: 286 | is_best = False 287 | no_better_valid += 1 288 | if no_better_valid > local_cfg.MODEL.TRAIN.EARLY_STOPPING_PATIENCE: 289 | # Finish training 290 | logging.info("Early stopping threshold reached. Stop training.") 291 | valid_loss_logger.log(iter=iter, loss=valid_loss, f1=valid_f1, 292 | acc=valid_accuracy, is_best=is_best) 293 | break 294 | ### add learning rate decay 295 | elif no_better_valid > local_cfg.MODEL.TRAIN.DECAY_PATIENCE: 296 | new_lr = max(trainer.learning_rate * local_cfg.MODEL.TRAIN.LR_DECAY_FACTOR, 297 | local_cfg.MODEL.TRAIN.MIN_LR) 298 | if new_lr < trainer.learning_rate: 299 | logging.info("Change the LR to %g" % new_lr) 300 | trainer.set_learning_rate(new_lr) 301 | no_better_valid = 0 302 | valid_loss_logger.log(iter=iter, loss=valid_loss, f1=valid_f1, acc=valid_accuracy, is_best=is_best) 303 | ### save best iter info 304 | logging.info("Best Valid: [Iter, Loss, F1, ACC] = %s" % str(best_valid_iter_info)) 305 | logging.info("Best Test : [Iter, Loss, F1, ACC] = %s" % str(best_test_iter_info)) 306 | valid_loss_logger.log(iter=best_valid_iter_info[0], 307 | loss=best_valid_iter_info[1], 308 | f1=best_valid_iter_info[2], 309 | acc=best_valid_iter_info[3], 310 | is_best=True) 311 | test_loss_logger.log(iter=best_test_iter_info[0], 312 | loss=best_test_iter_info[1], 313 | f1=best_test_iter_info[2], 314 | acc=best_test_iter_info[3]) 315 | if args.emails is not None and len(args.emails) > 0: 316 | for email_address in args.emails.split(','): 317 | send_msg(title=os.path.basename(args.save_dir), 318 | text="Test: [Iter, Loss, F1, ACC] = %s\n" % str(best_test_iter_info) 319 | + "Valid: [Iter, Loss, F1, ACC] = %s\n" % str(best_valid_iter_info) 320 | + 'Save Dir: %s\n' % args.save_dir 321 | + '\nConfig:\n' + ordered_dump(), 322 | dst_address=email_address) 323 | return 324 | 325 | def test(args, net, loss_function, data_iterator, save_id): 326 | 327 | net.load_params(os.path.join(args.save_dir, 'best_valid%d.params' % save_id), ctx=args.ctx[0]) 328 | test_loss, test_f1, test_accuracy = \ 329 | eval_classification(net=net, 330 | loss_function=loss_function, 331 | data_iterator=data_iterator, 332 | num_class=data_iterator.num_class, 333 | 
mode="test") 334 | logging.info("Test loss=%.4f, f1=%.4f, accuracy=%.4f" % (test_loss, test_f1, test_accuracy)) 335 | 336 | 337 | if __name__ == "__main__": 338 | args = parse_args() 339 | 340 | local_cfg = cfg.copy() 341 | del local_cfg.SPATIOTEMPORAL_GRAPH  # the static-graph experiments do not need the spatiotemporal config 342 | del local_cfg['SPATIOTEMPORAL_GRAPH']  # remove both the attribute and the dict entry 343 | 344 | args.save_id = save_cfg_dir(args.save_dir, source=local_cfg) 345 | if args.output_inner_result: 346 | if not os.path.isdir(os.path.join(args.save_dir, "inner_results%d" % args.save_id)): 347 | os.makedirs(os.path.join(args.save_dir, "inner_results%d" % args.save_id)) 348 | 349 | logging_config(folder=args.save_dir, name='sup_train_sample%d' % args.save_id, no_console=args.silent) 350 | logging.info(args) 351 | 352 | np.random.seed(cfg.NPY_SEED) 353 | mx.random.seed(cfg.MX_SEED) 354 | from mxgraph.graph import set_seed 355 | set_seed(cfg.MX_SEED) 356 | 357 | net, loss_function, data_iterator = build(args) 358 | 359 | if args.test: 360 | test(args, net, loss_function, data_iterator, 0) 361 | else: 362 | train(args, net, loss_function, data_iterator) -------------------------------------------------------------------------------- /GraphSampler/py_ext.cpp: -------------------------------------------------------------------------------- 1 | #include "Python.h" 2 | #include "numpy/arrayobject.h" 3 | #include "graph_sampler.h" 4 | #include <vector> 5 | #include <sstream> 6 | #include <cstring> 7 | #include <limits> 8 | 9 | #if PY_MAJOR_VERSION >= 3 10 | static PyObject *GraphSamplerError; 11 | static graph_sampler::GraphSampler handle; 12 | 13 | #define CHECK_SEQUENCE(obj) \ 14 | { \ 15 | if (!PySequence_Check(obj)) { \ 16 | PyErr_SetString(GraphSamplerError, "Need a sequence!"); \ 17 | return NULL; \ 18 | } \ 19 | } 20 | 21 | #define PY_CHECK_EQUAL(a, b) \ 22 | { \ 23 | if ((a) != (b)) { \ 24 | std::ostringstream err_msg; \ 25 | err_msg << "Line:" << __LINE__ << ", Check \"" << #a << " == " << #b << "\" failed"; \ 26 | PyErr_SetString(GraphSamplerError, err_msg.str().c_str()); \ 27 | return NULL; \ 28 | } \ 29 | } \ 30 | 31 | 32 | void alloc_npy_from_ptr(const int* arr_ptr, const size_t arr_size, PyObject** arr_obj) { 33 | npy_intp siz[] = { static_cast<npy_intp>(arr_size) }; 34 | *arr_obj = PyArray_EMPTY(1, siz, NPY_INT, 0); 35 | memcpy(PyArray_DATA(*arr_obj), static_cast<const void*>(arr_ptr), sizeof(int) * arr_size); 36 | return; 37 | } 38 | 39 | void alloc_npy_from_ptr(const float* arr_ptr, const size_t arr_size, PyObject** arr_obj) { 40 | npy_intp siz[] = { static_cast<npy_intp>(arr_size) }; 41 | *arr_obj = PyArray_EMPTY(1, siz, NPY_FLOAT32, 0); 42 | memcpy(PyArray_DATA(*arr_obj), static_cast<const void*>(arr_ptr), sizeof(float) * arr_size); 43 | return; 44 | } 45 | 46 | template <typename DType> 47 | void alloc_npy_from_vector(const std::vector<DType> &arr_vec, PyObject** arr_obj) { 48 | alloc_npy_from_ptr(arr_vec.data(), arr_vec.size(), arr_obj); 49 | return; 50 | } 51 | 52 | 53 | /* 54 | Inputs: 55 | NDArray(int) src_end_points, 56 | NDArray(int) src_ind_ptr, 57 | NDArray(int) src_node_ids, 58 | int src_undirected, 59 | int initial_node, 60 | int walk_length, 61 | double return_prob, 62 | int max_node_num, 63 | int max_edge_num 64 | --------------------------------- 65 | Outputs: 66 | NDArray(int) subgraph_end_points, 67 | NDArray(int) subgraph_ind_ptr, 68 | NDArray(int) subgraph_node_ids 69 | */ 70 | static PyObject* random_walk(PyObject* self, PyObject* args) { 71 | PyObject* src_end_points; 72 | PyObject* src_ind_ptr; 73 | PyObject* src_node_ids; 74 | int src_undirected; 75 | int initial_node; 76 | int walk_length; 77 | double return_prob; 78 | int max_node_num; 79 | int
max_edge_num; 80 | if (!PyArg_ParseTuple(args, "OOOiiidii", 81 | &src_end_points, 82 | &src_ind_ptr, 83 | &src_node_ids, 84 | &src_undirected, 85 | &initial_node, 86 | &walk_length, 87 | &return_prob, 88 | &max_node_num, 89 | &max_edge_num)) return NULL; 90 | PY_CHECK_EQUAL(PyArray_TYPE(src_end_points), NPY_INT32); 91 | PY_CHECK_EQUAL(PyArray_TYPE(src_ind_ptr), NPY_INT32); 92 | PY_CHECK_EQUAL(PyArray_TYPE(src_node_ids), NPY_INT32); 93 | 94 | int src_node_num = PyArray_SIZE(src_ind_ptr) - 1; 95 | int src_edge_num = PyArray_SIZE(src_end_points); 96 | if (src_undirected) { 97 | src_edge_num /= 2; 98 | } 99 | std::vector<int> subgraph_end_points_vec; 100 | std::vector<int> subgraph_ind_ptr_vec; 101 | std::vector<int> subgraph_node_ids_vec; 102 | graph_sampler::SimpleGraph* subgraph = handle.random_walk(static_cast<int*>(PyArray_DATA(src_end_points)), 103 | static_cast<int*>(PyArray_DATA(src_ind_ptr)), 104 | static_cast<int*>(PyArray_DATA(src_node_ids)), 105 | src_undirected, 106 | src_node_num, 107 | initial_node, 108 | walk_length, 109 | return_prob, 110 | max_node_num, 111 | max_edge_num); 112 | subgraph->convert_to_csr(&subgraph_end_points_vec, &subgraph_ind_ptr_vec, &subgraph_node_ids_vec, 113 | static_cast<int*>(PyArray_DATA(src_node_ids)), src_node_num); 114 | PyObject* subgraph_end_points = NULL; 115 | PyObject* subgraph_ind_ptr = NULL; 116 | PyObject* subgraph_node_ids = NULL; 117 | alloc_npy_from_vector(subgraph_end_points_vec, &subgraph_end_points); 118 | alloc_npy_from_vector(subgraph_ind_ptr_vec, &subgraph_ind_ptr); 119 | alloc_npy_from_vector(subgraph_node_ids_vec, &subgraph_node_ids); 120 | delete subgraph; 121 | return Py_BuildValue("(NNN)", subgraph_end_points, subgraph_ind_ptr, subgraph_node_ids); 122 | } 123 | 124 | /* 125 | Inputs: 126 | NDArray(int) src_end_points, 127 | NDArray(int) src_indptr, 128 | NDArray(int) node_indices, 129 | float neg_sample_scale, 130 | int replace 131 | --------------------------------- 132 | Outputs: 133 | NDArray(int) end_points, 134 | NDArray(int) indptr, 135 | NDArray(int) edge_label 136 | NDArray(int) edge_count 137 | */ 138 | static PyObject* uniform_neg_sampling(PyObject* self, PyObject* args) { 139 | PyObject* src_end_points; 140 | PyObject* src_ind_ptr; 141 | PyObject* node_indices; 142 | float neg_sample_scale; 143 | int replace; 144 | if (!PyArg_ParseTuple(args, "OOOfi", 145 | &src_end_points, 146 | &src_ind_ptr, 147 | &node_indices, 148 | &neg_sample_scale, 149 | &replace)) return NULL; 150 | // Check Type 151 | PY_CHECK_EQUAL(PyArray_TYPE(src_end_points), NPY_INT32); 152 | PY_CHECK_EQUAL(PyArray_TYPE(src_ind_ptr), NPY_INT32); 153 | PY_CHECK_EQUAL(PyArray_TYPE(node_indices), NPY_INT32); 154 | 155 | int src_node_num = PyArray_SIZE(src_ind_ptr) - 1; 156 | int src_nnz = PyArray_SIZE(src_end_points); 157 | int dst_node_num = PyArray_SIZE(node_indices); 158 | int* dst_end_points_d = NULL; 159 | int* dst_indptr_d = NULL; 160 | int* dst_edge_label_d = NULL; 161 | int* dst_edge_count_d = NULL; 162 | int dst_nnz = 0; 163 | handle.uniform_neg_sampling(static_cast<int*>(PyArray_DATA(src_end_points)), 164 | static_cast<int*>(PyArray_DATA(src_ind_ptr)), 165 | static_cast<int*>(PyArray_DATA(node_indices)), 166 | src_nnz, 167 | src_node_num, 168 | dst_node_num, 169 | neg_sample_scale, 170 | replace, 171 | &dst_end_points_d, 172 | &dst_indptr_d, 173 | &dst_edge_label_d, 174 | &dst_edge_count_d, 175 | &dst_nnz); 176 | PyObject* dst_end_points = NULL; 177 | PyObject* dst_indptr = NULL; 178 | PyObject* dst_edge_label = NULL; 179 | PyObject* dst_edge_count = NULL; 180 | alloc_npy_from_ptr(dst_end_points_d,
dst_nnz, &dst_end_points); 181 | alloc_npy_from_ptr(dst_indptr_d, dst_node_num + 1, &dst_indptr); 182 | alloc_npy_from_ptr(dst_edge_label_d, dst_nnz, &dst_edge_label); 183 | alloc_npy_from_ptr(dst_edge_count_d, dst_nnz, &dst_edge_count); 184 | delete[] dst_end_points_d; 185 | delete[] dst_indptr_d; 186 | delete[] dst_edge_label_d; 187 | delete[] dst_edge_count_d; 188 | return Py_BuildValue("(NNNN)", dst_end_points, dst_indptr, dst_edge_label, dst_edge_count); 189 | } 190 | 191 | /* 192 | Inputs: 193 | NDArray(int) src_end_points, 194 | NDArray(int) src_indptr, 195 | NDArray(int) src_node_ids, 196 | NDArray(int) sel_indices, 197 | int neighbor_num, 198 | float neighbor_frac, 199 | int sample_all, 200 | int replace 201 | --------------------------------- 202 | Outputs: 203 | NDArray(int) dst_end_points, 204 | NDArray(int) dst_ind_ptr, 205 | NDArray(int) merged_node_ids 206 | NDArray(int) indices_in_merged 207 | */ 208 | static PyObject* random_sel_neighbor_and_merge(PyObject* self, PyObject* args) { 209 | PyObject* src_end_points; 210 | PyObject* src_ind_ptr; 211 | PyObject* src_node_ids; 212 | PyObject* sel_indices; 213 | int neighbor_num; 214 | float neighbor_frac; 215 | int sample_all; 216 | int replace; 217 | if (!PyArg_ParseTuple(args, "OOOOifii", 218 | &src_end_points, 219 | &src_ind_ptr, 220 | &src_node_ids, 221 | &sel_indices, 222 | &neighbor_num, 223 | &neighbor_frac, 224 | &sample_all, 225 | &replace)) return NULL; 226 | // Check Type 227 | PY_CHECK_EQUAL(PyArray_TYPE(src_end_points), NPY_INT32); 228 | PY_CHECK_EQUAL(PyArray_TYPE(src_ind_ptr), NPY_INT32); 229 | PY_CHECK_EQUAL(PyArray_TYPE(src_node_ids), NPY_INT32); 230 | PY_CHECK_EQUAL(PyArray_TYPE(sel_indices), NPY_INT32); 231 | 232 | int src_node_num = PyArray_SIZE(src_ind_ptr) - 1; 233 | int src_nnz = PyArray_SIZE(src_end_points); 234 | int sel_node_num = PyArray_SIZE(sel_indices); 235 | std::vector<int> dst_end_points_vec; 236 | std::vector<int> dst_ind_ptr_vec; 237 | std::vector<int> merged_node_ids_vec; 238 | std::vector<int> indices_in_merged_vec; 239 | handle.random_sel_neighbor_and_merge(static_cast<int*>(PyArray_DATA(src_end_points)), 240 | static_cast<int*>(PyArray_DATA(src_ind_ptr)), 241 | static_cast<int*>(PyArray_DATA(src_node_ids)), 242 | static_cast<int*>(PyArray_DATA(sel_indices)), 243 | src_nnz, 244 | sel_node_num, 245 | neighbor_num, 246 | neighbor_frac, 247 | sample_all, 248 | replace, 249 | &dst_end_points_vec, 250 | &dst_ind_ptr_vec, 251 | &merged_node_ids_vec, 252 | &indices_in_merged_vec); 253 | PyObject* dst_end_points = NULL; 254 | PyObject* dst_ind_ptr = NULL; 255 | PyObject* merged_node_ids = NULL; 256 | PyObject* indices_in_merged = NULL; 257 | alloc_npy_from_vector(dst_end_points_vec, &dst_end_points); 258 | alloc_npy_from_vector(dst_ind_ptr_vec, &dst_ind_ptr); 259 | alloc_npy_from_vector(merged_node_ids_vec, &merged_node_ids); 260 | alloc_npy_from_vector(indices_in_merged_vec, &indices_in_merged); 261 | return Py_BuildValue("(NNNN)", dst_end_points, dst_ind_ptr, merged_node_ids, indices_in_merged); 262 | } 263 | 264 | /* 265 | Inputs: 266 | NDArray(int) src_end_points, 267 | NDArray(int) src_indptr, 268 | int initial_node, 269 | int max_node_num, 270 | int walk_length 271 | --------------------------------- 272 | Outputs: 273 | NDArray(int) dst_indices 274 | */ 275 | static PyObject* get_random_walk_nodes(PyObject* self, PyObject* args) { 276 | PyObject* src_end_points; 277 | PyObject* src_ind_ptr; 278 | int initial_node; 279 | int max_node_num; 280 | int walk_length; 281 | if (!PyArg_ParseTuple(args, "OOiii", 282 | &src_end_points, 283 |
&src_ind_ptr, 284 | &initial_node, 285 | &max_node_num, 286 | &walk_length)) return NULL; 287 | // Check Type 288 | PY_CHECK_EQUAL(PyArray_TYPE(src_end_points), NPY_INT32); 289 | PY_CHECK_EQUAL(PyArray_TYPE(src_ind_ptr), NPY_INT32); 290 | int src_node_num = PyArray_SIZE(src_ind_ptr) - 1; 291 | int nnz = PyArray_SIZE(src_end_points); 292 | std::vector<int> dst_indices_vec; 293 | handle.get_random_walk_nodes(static_cast<int*>(PyArray_DATA(src_end_points)), 294 | static_cast<int*>(PyArray_DATA(src_ind_ptr)), 295 | nnz, 296 | src_node_num, 297 | initial_node, 298 | max_node_num, 299 | walk_length, 300 | &dst_indices_vec); 301 | PyObject* dst_indices = NULL; 302 | alloc_npy_from_vector(dst_indices_vec, &dst_indices); 303 | return Py_BuildValue("N", dst_indices); 304 | } 305 | 306 | /* 307 | Inputs: 308 | int seed 309 | --------------------------------- 310 | Outputs: 311 | int ret_val 312 | */ 313 | static PyObject* set_seed(PyObject* self, PyObject* args) { 314 | int seed; 315 | if (!PyArg_ParseTuple(args, "i", &seed)) return NULL; 316 | handle.set_seed(seed); 317 | return Py_BuildValue("i", 1); 318 | } 319 | 320 | 321 | /* 322 | Inputs: 323 | NDArray(int) src_end_points 324 | NDArray(float32) or None src_values 325 | NDArray(int) src_ind_ptr 326 | NDArray(int) src_row_ids 327 | NDArray(int) src_col_ids 328 | NDArray(int) or None sel_row_indices 329 | NDArray(int) or None sel_col_indices 330 | --------------------------------- 331 | Outputs: 332 | NDArray(int) dst_end_points 333 | NDArray(float32) or None dst_values 334 | NDArray(int) dst_ind_ptr 335 | NDArray(int) dst_row_ids 336 | NDArray(int) dst_col_ids 337 | 338 | 339 | */ 340 | static PyObject* csr_submat(PyObject* self, PyObject* args) { 341 | PyObject* src_end_points; 342 | PyObject* src_values; 343 | PyObject* src_ind_ptr; 344 | PyObject* src_row_ids; 345 | PyObject* src_col_ids; 346 | PyObject* sel_row_indices; 347 | PyObject* sel_col_indices; 348 | if (!PyArg_ParseTuple(args, "OOOOOOO", 349 | &src_end_points, 350 | &src_values, 351 | &src_ind_ptr, 352 | &src_row_ids, 353 | &src_col_ids, 354 | &sel_row_indices, 355 | &sel_col_indices)) return NULL; 356 | // Check Type 357 | PY_CHECK_EQUAL(PyArray_TYPE(src_end_points), NPY_INT32); 358 | if(src_values != Py_None) { 359 | PY_CHECK_EQUAL(PyArray_TYPE(src_values), NPY_FLOAT32); 360 | } 361 | PY_CHECK_EQUAL(PyArray_TYPE(src_ind_ptr), NPY_INT32); 362 | PY_CHECK_EQUAL(PyArray_TYPE(src_row_ids), NPY_INT32); 363 | PY_CHECK_EQUAL(PyArray_TYPE(src_col_ids), NPY_INT32); 364 | if(sel_row_indices != Py_None) { 365 | PY_CHECK_EQUAL(PyArray_TYPE(sel_row_indices), NPY_INT32); 366 | } 367 | if(sel_col_indices != Py_None) { 368 | PY_CHECK_EQUAL(PyArray_TYPE(sel_col_indices), NPY_INT32); 369 | } 370 | 371 | 372 | long long src_row_num = PyArray_SIZE(src_row_ids); 373 | long long src_col_num = PyArray_SIZE(src_col_ids); 374 | long long src_nnz = PyArray_SIZE(src_end_points); 375 | ASSERT(src_row_num <= std::numeric_limits<int>::max()); 376 | ASSERT(src_col_num <= std::numeric_limits<int>::max()); 377 | ASSERT(src_nnz < std::numeric_limits<int>::max()); 378 | int dst_row_num = (sel_row_indices == Py_None) ? src_row_num : PyArray_SIZE(sel_row_indices); 379 | int dst_col_num = (sel_col_indices == Py_None) ? src_col_num : PyArray_SIZE(sel_col_indices); 380 | float* src_values_ptr = (src_values == Py_None) ? nullptr : static_cast<float*>(PyArray_DATA(src_values)); 381 | int* sel_row_indices_ptr = (sel_row_indices == Py_None) ? nullptr : static_cast<int*>(PyArray_DATA(sel_row_indices)); 382 | int* sel_col_indices_ptr = (sel_col_indices == Py_None) ?
nullptr : static_cast<int*>(PyArray_DATA(sel_col_indices)); 383 | int* dst_end_points_d = NULL; 384 | float* dst_values_d = NULL; 385 | int* dst_ind_ptr_d = NULL; 386 | int* dst_row_ids_d = NULL; 387 | int* dst_col_ids_d = NULL; 388 | int dst_nnz; 389 | graph_sampler::slice_csr_mat(static_cast<int*>(PyArray_DATA(src_end_points)), 390 | src_values_ptr, 391 | static_cast<int*>(PyArray_DATA(src_ind_ptr)), 392 | static_cast<int*>(PyArray_DATA(src_row_ids)), 393 | static_cast<int*>(PyArray_DATA(src_col_ids)), 394 | src_row_num, 395 | src_col_num, 396 | src_nnz, 397 | sel_row_indices_ptr, 398 | sel_col_indices_ptr, 399 | dst_row_num, 400 | dst_col_num, 401 | &dst_end_points_d, 402 | &dst_values_d, 403 | &dst_ind_ptr_d, 404 | &dst_row_ids_d, 405 | &dst_col_ids_d, 406 | &dst_nnz); 407 | PyObject* dst_end_points = NULL; 408 | PyObject* dst_values = NULL; 409 | PyObject* dst_ind_ptr = NULL; 410 | PyObject* dst_row_ids = NULL; 411 | PyObject* dst_col_ids = NULL; 412 | alloc_npy_from_ptr(dst_end_points_d, dst_nnz, &dst_end_points); 413 | if (dst_values_d == nullptr) { 414 | Py_INCREF(Py_None); 415 | dst_values = Py_None; 416 | } else { 417 | alloc_npy_from_ptr(dst_values_d, dst_nnz, &dst_values); 418 | } 419 | alloc_npy_from_ptr(dst_ind_ptr_d, dst_row_num + 1, &dst_ind_ptr); 420 | alloc_npy_from_ptr(dst_row_ids_d, dst_row_num, &dst_row_ids); 421 | alloc_npy_from_ptr(dst_col_ids_d, dst_col_num, &dst_col_ids); 422 | 423 | // Clear allocated variables 424 | delete[] dst_end_points_d; 425 | if (dst_values_d != nullptr) delete[] dst_values_d; 426 | delete[] dst_ind_ptr_d; 427 | delete[] dst_row_ids_d; 428 | delete[] dst_col_ids_d; 429 | return Py_BuildValue("(NNNNN)", dst_end_points, dst_values, dst_ind_ptr, dst_row_ids, dst_col_ids); 430 | } 431 | 432 | 433 | static PyMethodDef myextension_methods[] = { 434 | {"random_walk", (PyCFunction)random_walk, METH_VARARGS, NULL}, 435 | {"uniform_neg_sampling", (PyCFunction)uniform_neg_sampling, METH_VARARGS, NULL}, 436 | {"random_sel_neighbor_and_merge", (PyCFunction)random_sel_neighbor_and_merge, METH_VARARGS, NULL}, 437 | {"get_random_walk_nodes", (PyCFunction)get_random_walk_nodes, METH_VARARGS, NULL}, 438 | {"set_seed", (PyCFunction)set_seed, METH_VARARGS, NULL}, 439 | {"csr_submat", (PyCFunction)csr_submat, METH_VARARGS, NULL}, 440 | {NULL, NULL} 441 | }; 442 | 443 | static struct PyModuleDef moduledef = { 444 | PyModuleDef_HEAD_INIT, 445 | "_graph_sampler", 446 | NULL, 447 | -1, 448 | myextension_methods 449 | }; 450 | 451 | 452 | PyMODINIT_FUNC 453 | PyInit__graph_sampler(void) 454 | { 455 | PyObject *m = PyModule_Create(&moduledef); 456 | if (m == NULL) 457 | return NULL; 458 | import_array(); 459 | GraphSamplerError = PyErr_NewException("graph_sampler.error", NULL, NULL); 460 | Py_INCREF(GraphSamplerError); 461 | PyModule_AddObject(m, "graph_sampler.error", GraphSamplerError); 462 | return m; 463 | } 464 | #endif 465 | 466 | -------------------------------------------------------------------------------- /mxgraph/layers/graph_rnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from collections import namedtuple 4 | import mxnet as mx 5 | import numpy as np 6 | import mxnet.ndarray as nd 7 | import mxnet.gluon as gluon 8 | import copy 9 | from mxnet.gluon import nn, HybridBlock 10 | from mxgraph.config import cfg 11 | from mxgraph.utils import safe_eval 12 | from mxgraph.layers import MuGGA 13 | from mxgraph.layers.common import * 14 | from mxgraph.layers.aggregators import parse_aggregator_from_desc,
get_activation 15 | 16 | 17 | def shortcut_aggregator_forward(aggregator, X, Z, end_points, indptr, edge_data): 18 | if isinstance(aggregator, MuGGA): 19 | data, gate, sharpness, attend_weights_wo_gate =\ 20 | aggregator(X, Z, end_points, indptr, edge_data) 21 | mid_info = [gate, sharpness, attend_weights_wo_gate] 22 | else: 23 | data = aggregator(X, Z, end_points, indptr, edge_data) 24 | mid_info = [] 25 | return data, mid_info 26 | 27 | 28 | class GraphConvGRUCell(HybridBlock): 29 | def __init__(self, prefix=None, params=None): 30 | super(GraphConvGRUCell, self).__init__(prefix=prefix, params=params) 31 | 32 | 33 | class GraphRNNCell(HybridBlock): 34 | def __init__(self, aggregator_args, aggregation_type='all', typ='rnn', prefix=None, params=None): 35 | super(GraphRNNCell, self).__init__(prefix=prefix, params=params) 36 | assert len(aggregator_args) == 2 37 | self._aggregator_args = copy.deepcopy(aggregator_args) 38 | # In order to compute gates in GRU, we need to scale up the number of out_units 39 | self._act = get_activation(cfg.AGGREGATOR.ACTIVATION) 40 | self._state_dim = aggregator_args[1][0] 41 | self._typ = typ.lower() 42 | assert self._typ in ['rnn', 'lstm'] 43 | self._map_dim = 4 * self._state_dim if self._typ == 'lstm' else self._state_dim 44 | self._aggregation_type = aggregation_type 45 | with self.name_scope(): 46 | if self._aggregation_type == 'all' or self._aggregation_type == 'in'\ 47 | or self._aggregation_type == 'state_only': 48 | self._aggregator_args[1][0] = self._map_dim 49 | self.aggregator = parse_aggregator_from_desc( 50 | aggregator_desc=[self._aggregator_args[0], 51 | self._aggregator_args[1], 52 | 'agg_', 'identity']) 53 | if self._aggregation_type == 'in' or self._aggregation_type == 'no_agg'\ 54 | or self._aggregation_type == 'state_only': 55 | self.direct_map = nn.Dense(units=self._map_dim, flatten=False, prefix='direct_') 56 | 57 | def summary(self): 58 | logging.info("Graph %s: State Dim=%d, Type=%s, Aggregator=%s" 59 | % (self._typ.upper(), self._state_dim, 60 | self._aggregation_type, str(self._aggregator_args))) 61 | 62 | def hybrid_forward(self, F, data, states, 63 | end_points, indptr, edge_data=None): 64 | """ 65 | 66 | Parameters 67 | ---------- 68 | F 69 | data : Symbol or NDArray 70 | Shape (batch_size, node_num, feat_dim) 71 | states : list Symbol or NDArray 72 | Shape (batch_size, node_num, state_dim) 73 | endpoints : Symbol or NDArray 74 | Shape (nnz_x2h, ) 75 | indptr : Symbol or NDArray 76 | Shape (node_num + 1, ) 77 | edge_data : Symbol or NDArray or None 78 | Shape (nnz_x2h, ) 79 | Returns 80 | ------- 81 | next_hidden : Symbol or NDArray 82 | Shape (batch_size, node_num, feat_dim) 83 | mid_info : None or list 84 | """ 85 | if states is None: 86 | prev_h = F.broadcast_axis(F.slice_axis(F.zeros_like(data), begin=0, end=1, axis=2), 87 | axis=2, size=self._state_dim) 88 | if self._typ == 'lstm': 89 | prev_c = prev_h 90 | else: 91 | if self._typ == "lstm": 92 | prev_h, prev_c = states 93 | else: 94 | prev_h = states[0] 95 | if self._aggregation_type == 'all': 96 | concat_data = F.concat(data, prev_h, dim=-1) 97 | h_data, mid_info = shortcut_aggregator_forward(self.aggregator, concat_data, 98 | concat_data, end_points, 99 | indptr, edge_data) 100 | elif self._aggregation_type == 'in': 101 | h_data, mid_info = shortcut_aggregator_forward(self.aggregator, data, 102 | data, end_points, 103 | indptr, edge_data) 104 | h_data = h_data + self.direct_map(prev_h) 105 | elif self._aggregation_type == 'state_only': 106 | h_data, mid_info = 
shortcut_aggregator_forward(self.aggregator, data, 107 | prev_h, end_points, 108 | indptr, edge_data) 109 | h_data = h_data + self.direct_map(prev_h) 110 | elif self._aggregation_type == 'no_agg': 111 | h_data = self.direct_map(F.concat(data, prev_h, dim=-1)) 112 | mid_info = [] 113 | else: 114 | raise NotImplementedError 115 | if self._typ == 'lstm': 116 | h_data_l = F.split(h_data, num_outputs=4, axis=2) 117 | F_t = F.Activation(h_data_l[0], act_type="sigmoid") # Forget Gate 118 | I_t = self._act(h_data_l[1]) # Input Gate 119 | O_t = F.Activation(h_data_l[2], act_type="sigmoid") # Output Gate 120 | H_t = F.Activation(h_data_l[3], act_type="sigmoid") # Info 121 | new_c = F_t * prev_c + I_t * H_t 122 | new_h = O_t * self._act(new_c) 123 | return [new_h, new_c], mid_info 124 | else: 125 | new_h = self._act(h_data) 126 | return [new_h], mid_info 127 | 128 | 129 | class GraphGRUCell(HybridBlock): 130 | def __init__(self, aggregator_args, aggregation_type='all', typ='rnn', prefix=None, params=None): 131 | super(GraphGRUCell, self).__init__(prefix=prefix, params=params) 132 | assert len(aggregator_args) == 2 133 | self._aggregator_args = copy.deepcopy(aggregator_args) 134 | # In order to compute gates in GRU, we need to scale up the number of out_units 135 | self._act = get_activation(cfg.AGGREGATOR.ACTIVATION) 136 | self._state_dim = aggregator_args[1][0] 137 | self._typ = "gru" 138 | self._map_dim = 3 * self._state_dim 139 | self._aggregation_type = aggregation_type 140 | with self.name_scope(): 141 | self._aggregator_args[1][0] = self._map_dim 142 | self.aggregator_x2h = parse_aggregator_from_desc( 143 | aggregator_desc=[self._aggregator_args[0], 144 | self._aggregator_args[1], 145 | 'agg_x2h_', 'identity']) 146 | self.aggregator_h2h = parse_aggregator_from_desc( 147 | aggregator_desc=[self._aggregator_args[0], 148 | self._aggregator_args[1], 149 | 'agg_h2h_', 'identity']) 150 | if self._aggregation_type != 'concat': 151 | self.direct_h2h = nn.Dense(units=self._map_dim, flatten=False, prefix='direct_h2h_') 152 | 153 | def summary(self): 154 | logging.info("Graph %s: State Dim=%d, Aggregation Type=%s, Aggregator=%s" 155 | % (self._typ.upper(), self._state_dim, self._aggregation_type, str(self._aggregator_args))) 156 | 157 | def hybrid_forward(self, F, data, states, 158 | end_points, indptr, edge_data=None): 159 | if states is None: 160 | prev_h = F.broadcast_axis(F.slice_axis(F.zeros_like(data), begin=0, end=1, axis=2), 161 | axis=2, size=self._state_dim) 162 | else: 163 | prev_h = states[0] 164 | x2h_data, x2h_mid_info = shortcut_aggregator_forward(self.aggregator_x2h, data, 165 | data, end_points, 166 | indptr, edge_data) 167 | if self._aggregation_type != 'concat': 168 | h2h_data, h2h_mid_info = shortcut_aggregator_forward(self.aggregator_h2h, data, 169 | prev_h, end_points, indptr, edge_data) 170 | h2h_data = h2h_data + self.direct_h2h(h2h_data) 171 | else: 172 | h2h_data, h2h_mid_info = shortcut_aggregator_forward(self.aggregator_h2h, 173 | F.concat(data, prev_h, dim=-1), 174 | prev_h, end_points, indptr, edge_data) 175 | mid_info = x2h_mid_info + h2h_mid_info 176 | 177 | x2h_data_l = F.split(x2h_data, num_outputs=3, axis=2) 178 | h2h_data_l = F.split(h2h_data, num_outputs=3, axis=2) 179 | U_t = F.Activation(x2h_data_l[0] + h2h_data_l[0], act_type='sigmoid') 180 | R_t = F.Activation(x2h_data_l[1] + h2h_data_l[1], act_type='sigmoid') 181 | H_prime_t = self._act(x2h_data_l[2] + R_t * h2h_data_l[2]) 182 | H_t = (1 - U_t) * H_prime_t + U_t * prev_h 183 | return [H_t], mid_info 184 | 185 | 
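# Usage sketch (illustrative, not part of the library): one step of a GraphGRUCell
# over a toy 3-node graph in CSR form. The aggregator descriptor ['pool_avg', [8]]
# is an assumed example value; substitute any descriptor that
# parse_aggregator_from_desc accepts in your setup.
def _toy_graph_gru_step():
    cell = GraphGRUCell(aggregator_args=['pool_avg', [8]], aggregation_type='concat')
    cell.initialize()
    data = nd.random.normal(shape=(1, 3, 4))            # (batch, node_num, feat_dim)
    end_points = nd.array([1, 2, 0, 0], dtype='int32')  # neighbors of node i: end_points[indptr[i]:indptr[i + 1]]
    indptr = nd.array([0, 2, 3, 4], dtype='int32')      # (node_num + 1,)
    states, mid_info = cell(data, None, end_points, indptr)  # states=None -> zero initial hidden state
    return states[0]                                    # next hidden state, roughly (1, 3, 8)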
186 | class StackGraphRNN(HybridBlock): 187 | def __init__(self, out_units, aggregator_args_list, aggregation_type, rnn_type, dropout, 188 | in_length, out_length, prefix=None, params=None): 189 | super(StackGraphRNN, self).__init__(prefix=prefix, params=params) 190 | self._in_length = in_length 191 | self._out_length = out_length 192 | self._layer_num = len(aggregator_args_list) 193 | self._out_units = out_units 194 | self._dropout = dropout 195 | self._aggregation_type = aggregation_type 196 | self._rnn_type = rnn_type 197 | self._act = get_activation(cfg.AGGREGATOR.ACTIVATION) 198 | with self.name_scope(): 199 | self.enc_pre_embed = nn.HybridSequential(prefix='enc_pre_embed_') 200 | self.dec_pre_embed = nn.HybridSequential(prefix='dec_pre_embed_') 201 | with self.enc_pre_embed.name_scope(): 202 | self.enc_pre_embed.add(nn.Dense(units=16, flatten=False)) 203 | self.enc_pre_embed.add(self._act) 204 | with self.dec_pre_embed.name_scope(): 205 | self.dec_pre_embed.add(nn.Dense(units=16, flatten=False)) 206 | self.dec_pre_embed.add(self._act) 207 | self.dropout_layer = nn.Dropout(rate=dropout) 208 | self.encoder_graph_rnn_cells = nn.HybridSequential() 209 | for i, layer_args in enumerate(aggregator_args_list): 210 | self.encoder_graph_rnn_cells.add(GraphGRUCell(aggregator_args=layer_args, 211 | aggregation_type=aggregation_type, 212 | typ=self._rnn_type, 213 | prefix="enc_graph_rnn%d_"%i)) 214 | self.decoder_graph_rnn_cells = nn.HybridSequential() 215 | for i, layer_args in enumerate(aggregator_args_list): 216 | self.decoder_graph_rnn_cells.add(GraphGRUCell(aggregator_args=layer_args, 217 | aggregation_type=aggregation_type, 218 | typ=self._rnn_type, 219 | prefix="dec_graph_rnn%d_"%i)) 220 | self.out_layer = nn.Dense(units=self._out_units, flatten=False, 221 | prefix="out_") 222 | 223 | def summary(self): 224 | logging.info("Stack Graph RNN: in_length=%d, out_length=%d" 225 | % (self._in_length, self._out_length)) 226 | logging.info("Encoder:") 227 | for i in range(len(self.encoder_graph_rnn_cells)): 228 | self.encoder_graph_rnn_cells[i].summary() 229 | logging.info("Decoder:") 230 | for i in range(len(self.decoder_graph_rnn_cells)): 231 | self.decoder_graph_rnn_cells[i].summary() 232 | 233 | def hybrid_forward(self, F, data_in, data_out, gt_prob, out_additional_feature, 234 | end_points, indptr, edge_data=None): 235 | """ 236 | 237 | Parameters 238 | ---------- 239 | F 240 | data_in : 241 | Shape (in_length, batch_size, node_num, feat_dim) 242 | Will be normalized!!! 243 | data_out: 244 | Shape (out_length, batch_size, node_num, feat_dim) 245 | Will be normalized!!! 
246 | gt_prob: 247 | Shape (1,) 248 | The probability to use the groundtruth 249 | out_additional_feature: 250 | Shape (out_length, batch_size, node_num, time_feat_dim) 251 | endpoints 252 | Shape (nnz_x2h,) 253 | indptr 254 | Shape (node_num + 1,) 255 | edge_data 256 | Shape (nnz_x2h, edge_feat_dim) 257 | Returns 258 | ------- 259 | pred: 260 | Shape (out_length, batch_size, node_num, out_dim) 261 | enc_mid_info: list 262 | dec_mid_info: list 263 | """ 264 | states = [None for _ in range(self._layer_num)] 265 | data_in = self.enc_pre_embed(data_in) 266 | data_in_l = F.split(data_in, num_outputs=self._in_length, axis=0, squeeze_axis=True) 267 | data_out_l = F.split(data_out, num_outputs=self._out_length, axis=0, squeeze_axis=True) 268 | out_additional_feature_l = F.split(out_additional_feature, num_outputs=self._out_length, 269 | axis=0, squeeze_axis=True) 270 | enc_mid_info = [[] for _ in range(self._layer_num)] 271 | dec_mid_info = [[] for _ in range(self._layer_num)] 272 | pred_l = [] 273 | curr_in = None 274 | for j in range(self._in_length): 275 | curr_in = data_in_l[j] 276 | for i in range(self._layer_num): 277 | new_states, mid_info = self.encoder_graph_rnn_cells[i](curr_in, states[i], 278 | end_points, indptr, edge_data) 279 | states[i] = new_states 280 | enc_mid_info[i].extend(mid_info) 281 | curr_in = self.dropout_layer(new_states[0]) 282 | pred_l.append(self.out_layer(curr_in)) 283 | # Begin forecaster 284 | for j in range(self._out_length - 1): 285 | use_gt = F.random_uniform(0, 1) < gt_prob 286 | data_in_after_ss = F.broadcast_mul(use_gt, data_out_l[j])\ 287 | + F.broadcast_mul(1 - use_gt, pred_l[-1]) 288 | curr_in = F.concat(data_in_after_ss, out_additional_feature_l[j], dim=-1) 289 | curr_in = self.dec_pre_embed(curr_in) 290 | for i in range(self._layer_num): 291 | new_states, mid_info = self.decoder_graph_rnn_cells[i](curr_in, states[i], 292 | end_points, indptr, edge_data) 293 | states[i] = new_states 294 | dec_mid_info[i].extend(mid_info) 295 | curr_in = self.dropout_layer(new_states[0]) 296 | pred_l.append(self.out_layer(curr_in)) 297 | pred = F.concat(*[F.expand_dims(ele, axis=0) for ele in pred_l], dim=0) 298 | enc_gates = [] 299 | enc_sharpness = [] 300 | enc_attention_weights = [] 301 | dec_gates = [] 302 | dec_sharpness = [] 303 | dec_attention_weights = [] 304 | for i in range(self._layer_num): 305 | if len(enc_mid_info[i]) > 0: 306 | all_gates = F.concat(*[F.expand_dims(ele, 0) for ele in enc_mid_info[i][::3]], 307 | dim=0) 308 | all_sharpness = F.concat(*[F.expand_dims(ele, 0) for ele in enc_mid_info[i][1::3]], 309 | dim=0) 310 | all_attention_weights = F.concat( 311 | *[F.expand_dims(ele, 0) for ele in enc_mid_info[i][2::3]], 312 | dim=0) 313 | enc_gates.append(all_gates) 314 | enc_sharpness.append(all_sharpness) 315 | enc_attention_weights.append(all_attention_weights) 316 | if len(dec_mid_info[i]) > 0: 317 | all_gates = F.concat(*[F.expand_dims(ele, 0) for ele in dec_mid_info[i][::3]], 318 | dim=0) 319 | all_sharpness = F.concat(*[F.expand_dims(ele, 0) for ele in dec_mid_info[i][1::3]], 320 | dim=0) 321 | all_attention_weights = F.concat( 322 | *[F.expand_dims(ele, 0) for ele in dec_mid_info[i][2::3]], dim=0) 323 | dec_gates.append(all_gates) 324 | dec_sharpness.append(all_sharpness) 325 | dec_attention_weights.append(all_attention_weights) 326 | return pred, enc_gates, enc_sharpness, enc_attention_weights, \ 327 | dec_gates, dec_sharpness, dec_attention_weights 328 | 329 | 330 | # class StackGraphRNN2(StackGraphRNN): 331 | # def hybrid_forward(self, F, 
data_in, out_additional_feature, 332 | # end_points, indptr, edge_data=None): 333 | # """ 334 | # 335 | # Parameters 336 | # ---------- 337 | # F 338 | # data_in : 339 | # Shape (in_length, batch_size, node_num, feat_dim) 340 | # Will be normalized!!! 341 | # out_additional_feature: 342 | # Shape (out_length, batch_size, node_num, time_feat_dim) 343 | # endpoints 344 | # Shape (nnz_x2h,) 345 | # indptr 346 | # Shape (node_num + 1,) 347 | # edge_data 348 | # Shape (nnz_x2h, edge_feat_dim) 349 | # Returns 350 | # ------- 351 | # pred: 352 | # Shape (out_length, batch_size, node_num, out_dim) 353 | # enc_mid_info: list 354 | # dec_mid_info: list 355 | # """ 356 | # states = [None for _ in range(self._layer_num)] 357 | # data_in = self.enc_pre_embed(data_in) 358 | # data_in_l = F.split(data_in, num_outputs=self._in_length, axis=0, squeeze_axis=True) 359 | # out_additional_feature_l = F.split(out_additional_feature, num_outputs=self._out_length, 360 | # axis=0, squeeze_axis=True) 361 | # enc_mid_info = [[] for _ in range(self._layer_num)] 362 | # dec_mid_info = [[] for _ in range(self._layer_num)] 363 | # pred_l = [] 364 | # for j in range(self._in_length): 365 | # curr_in = data_in_l[j] 366 | # for i in range(self._layer_num): 367 | # new_states, mid_info = self.encoder_graph_rnn_cells[i](curr_in, states[i], 368 | # end_points, indptr, 369 | # edge_data) 370 | # states[i] = new_states 371 | # enc_mid_info[i].extend(mid_info) 372 | # curr_in = self.dropout_layer(new_states[0]) 373 | # # Reverse the states 374 | # for j in range(self._out_length): 375 | # curr_in = self.dec_pre_embed(out_additional_feature_l[j]) 376 | # for i in range(self._layer_num - 1, -1, -1): 377 | # new_states, mid_info = self.decoder_graph_rnn_cells[i](curr_in, states[i], 378 | # end_points, indptr, 379 | # edge_data) 380 | # states[i] = new_states 381 | # dec_mid_info[i].extend(mid_info) 382 | # curr_in = self.dropout_layer(new_states[0]) 383 | # pred_l.append(self.out_layer(curr_in)) 384 | # pred = F.concat(*[F.expand_dims(ele, axis=0) for ele in pred_l], dim=0) 385 | # enc_gates = [] 386 | # enc_sharpness = [] 387 | # enc_attention_weights = [] 388 | # dec_gates = [] 389 | # dec_sharpness = [] 390 | # dec_attention_weights = [] 391 | # for i in range(self._layer_num): 392 | # if len(enc_mid_info[i]) > 0: 393 | # all_gates = F.concat(*[F.expand_dims(ele, 0) for ele in enc_mid_info[i][::3]], 394 | # dim=0) 395 | # all_sharpness = F.concat(*[F.expand_dims(ele, 0) for ele in enc_mid_info[i][1::3]], 396 | # dim=0) 397 | # all_attention_weights = F.concat(*[F.expand_dims(ele, 0) for ele in enc_mid_info[i][2::3]], 398 | # dim=0) 399 | # enc_gates.append(all_gates) 400 | # enc_sharpness.append(all_sharpness) 401 | # enc_attention_weights.append(all_attention_weights) 402 | # if len(dec_mid_info[i]) > 0: 403 | # all_gates = F.concat(*[F.expand_dims(ele, 0) for ele in dec_mid_info[i][::3]], 404 | # dim=0) 405 | # all_sharpness = F.concat(*[F.expand_dims(ele, 0) for ele in dec_mid_info[i][1::3]], 406 | # dim=0) 407 | # all_attention_weights = F.concat( 408 | # *[F.expand_dims(ele, 0) for ele in dec_mid_info[i][2::3]], dim=0) 409 | # dec_gates.append(all_gates) 410 | # dec_sharpness.append(all_sharpness) 411 | # dec_attention_weights.append(all_attention_weights) 412 | # return pred, enc_gates, enc_sharpness, enc_attention_weights,\ 413 | # dec_gates, dec_sharpness, dec_attention_weights 414 | -------------------------------------------------------------------------------- /mxgraph/layers/stack_layers.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from mxnet.gluon import Block 3 | from .aggregators import * 4 | 5 | 6 | class BaseGraphMultiLayer(HybridBlock): 7 | def __init__(self, out_units, aggregator_args_list, dropout_rate_list, graph_type="homo", in_units=None, first_embed_units=256, 8 | dense_connect=False, every_layer_l2_normalization=False, l2_normalization=False, 9 | output_inner_result=False, 10 | prefix=None, params=None): 11 | super(BaseGraphMultiLayer, self).__init__(prefix=prefix, params=params) 12 | self._aggregator_args_list = aggregator_args_list 13 | self._dropout_rate_list = dropout_rate_list 14 | self._dense_connect = dense_connect 15 | self._first_embed_units = first_embed_units 16 | self._every_layer_l2_normalization = every_layer_l2_normalization 17 | self._l2_normalization = l2_normalization 18 | self._graph_type = graph_type 19 | if self._l2_normalization: 20 | self._l2_normalization_layer = L2Normalization(axis=-1) 21 | self._output_inner_result = output_inner_result 22 | print("graph type", graph_type) 23 | with self.name_scope(): 24 | if graph_type == "homo": 25 | self.embed = nn.Dense(self._first_embed_units, flatten=False) 26 | self.aggregators = nn.HybridSequential() 27 | self.dropout_layers = nn.HybridSequential() 28 | for args, dropout_rate in zip(aggregator_args_list, dropout_rate_list): 29 | self.aggregators.add(parse_aggregator_from_desc(args)) 30 | self.dropout_layers.add(nn.Dropout(dropout_rate)) 31 | self.out_layer = nn.Dense(units=out_units, flatten=False) 32 | 33 | class GraphMultiLayerAllNodes(BaseGraphMultiLayer): 34 | """Generate the output for all nodes in the graph 35 | 36 | """ 37 | def hybrid_forward(self, F, data, end_points, indptr, edge_data=None): 38 | """ 39 | 40 | Parameters 41 | ---------- 42 | F 43 | data : Symbol or NDArray 44 | Shape: (node_num, feat_dim) 45 | end_points: Symbol or NDArray 46 | Shape: (nnz,) 47 | indptr: Symbol or NDArray 48 | Shape: (node_num,) 49 | edge_data: Symbol or NDArray or None 50 | Shape: (nnz, edge_dim) 51 | Returns 52 | ------- 53 | ret: Symbol or NDArray 54 | """ 55 | ### DO NOT ADD THE DEGREE FEATURE 56 | # degrees = F.slice_axis(indptr, axis=-1, begin=1, end=None) \ 57 | # - F.slice_axis(indptr, axis=-1, begin=0, end=-1) 58 | # log_degrees = F.log(degrees.astype("float32") + 1.0) 59 | # data = F.concat(data, F.expand_dims(log_degrees, axis=1), dim=1) 60 | data = self.embed(data) 61 | layer_in_l = [data] 62 | # Forward through the aggregators 63 | gate_l = [] 64 | sharpness_l = [] 65 | attend_weights_wo_gate_l = [] 66 | for aggregator, dropout_layer in zip(self.aggregators, self.dropout_layers): 67 | layer_in = layer_in_l[-1] 68 | # When the whole graph is used, neighbor_data will always be equal to data 69 | if isinstance(aggregator, MuGGA): 70 | out, gate, sharpness, attend_weights_wo_gate =\ 71 | aggregator(F.expand_dims(layer_in, axis=0), 72 | F.expand_dims(layer_in, axis=0), 73 | end_points, indptr, edge_data) 74 | gate_l.append(gate) 75 | sharpness_l.append(sharpness) 76 | attend_weights_wo_gate_l.append(attend_weights_wo_gate) 77 | else: 78 | out = aggregator(F.expand_dims(layer_in, axis=0), 79 | F.expand_dims(layer_in, axis=0), 80 | end_points, indptr, edge_data) 81 | out = F.reshape(out, shape=(0, 0), reverse=True) 82 | if self._every_layer_l2_normalization: 83 | out = self._l2_normalization_layer(out) 84 | else: 85 | out = dropout_layer(out) 86 | if self._dense_connect: 87 | # We are actually implementing skip connection 88 | out = 
F.concat(layer_in, out, dim=-1) 89 | layer_in_l.append(out) 90 | # Go through the output layer 91 | out = self.out_layer(layer_in_l[-1]) 92 | if self._l2_normalization: 93 | out = self._l2_normalization_layer(out) 94 | if self._output_inner_result: 95 | return out, gate_l, sharpness_l, attend_weights_wo_gate_l 96 | else: 97 | return out 98 | 99 |
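# Usage sketch (illustrative, not part of the library): run GraphMultiLayerAllNodes
# over a toy 5-node CSR graph. The aggregator descriptor below follows the
# [name, args, prefix, activation] layout used in graph_rnn.py; 'pool_avg' and its
# argument list are assumed example values.
def _toy_all_nodes_forward():
    import mxnet.ndarray as nd
    net = GraphMultiLayerAllNodes(out_units=3,
                                  aggregator_args_list=[['pool_avg', [16], 'agg0_', 'identity']],
                                  dropout_rate_list=[0.0])
    net.initialize()
    data = nd.random.normal(shape=(5, 8))                  # (node_num, feat_dim)
    end_points = nd.array([1, 0, 3, 4, 2, 2], dtype='int32')
    indptr = nd.array([0, 1, 2, 4, 5, 6], dtype='int32')   # (node_num + 1,)
    return net(data, end_points, indptr)                   # (node_num, out_units)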
100 | class GraphMultiLayerHierarchicalNodes(BaseGraphMultiLayer): 101 | """Generate the outputs for part of the nodes. 102 | 103 | The other nodes are appended to the list level by level (Level Set) 104 | 105 | """ 106 | 107 | def hybrid_forward(self, F, data, end_points_l, indptr_l, indices_in_merged_l=None, edge_data_l=None): 108 | """ 109 | We aggregate the lower layer information based on the end_points and indptr information 110 | 111 | Note that here the end_points, indptr and indices are ordered from the highest layer to the lowest layer 112 | 113 | Remember that the data should be from the 0th layer! 114 | 115 | Parameters 116 | ---------- 117 | F 118 | data : Symbol or NDArray 119 | The basic features (Features in Layer0) 120 | end_points_l: list of Symbol or NDArray 121 | Stores the edge connectivity within a layer 122 | Shape: [(NNZ_1,), (NNZ_2,), ...] 123 | indptr_l: list 124 | Pointers into the corresponding end_points 125 | Shape: [(N_1,), (N_2,), ...] 126 | indices_in_merged_l: list 127 | indices_in_merged_l[i] stores the positions (w.r.t. layer i) of the nodes that appear in layer i+1. 128 | To be more specific, indices_in_merged_l[0][j] stores the index of the jth node of Layer 1 w.r.t. Layer 0 (Basic Features Layer) 129 | We can select the features from the previous layer by calling `features_i[indices_in_merged_l[i], :]` 130 | Shape: [(N_1,), (N_2,), ...] 131 | edge_data_l: list 132 | Edge features corresponding to the end_points 133 | Shape: [(NNZ_1, fdim1), (NNZ_2, fdim2), ...] 134 | Returns 135 | ------- 136 | ret: Symbol or NDArray 137 | """ 138 | layer_num = len(self.aggregators) 139 | if indices_in_merged_l is not None: 140 | assert len(end_points_l) == layer_num 141 | assert len(indptr_l) == layer_num 142 | assert len(indices_in_merged_l) == layer_num 143 | data = self.embed(data) 144 | layer_in_l = [data] 145 | gate_l = [] 146 | sharpness_l = [] 147 | attend_weights_wo_gate_l = [] 148 | for i, (aggregator, dropout_layer) in enumerate(zip(self.aggregators, self.dropout_layers)): 149 | lower_layer_data = layer_in_l[-1] 150 | if indices_in_merged_l is None: 151 | end_points = end_points_l 152 | indptr = indptr_l 153 | upper_layer_base_data = lower_layer_data 154 | else: 155 | end_points = end_points_l[i] 156 | indptr = indptr_l[i] 157 | upper_layer_base_data = F.take(lower_layer_data, 158 | indices=indices_in_merged_l[i], axis=0) 159 | edge_data = None if edge_data_l is None else edge_data_l[i] 160 | if isinstance(aggregator, MuGGA): 161 | out, gate, sharpness, attend_weights_wo_gate =\ 162 | aggregator(F.expand_dims(upper_layer_base_data, axis=0), 163 | F.expand_dims(lower_layer_data, axis=0), 164 | end_points, 165 | indptr, 166 | edge_data) 167 | gate_l.append(gate) 168 | sharpness_l.append(sharpness) 169 | attend_weights_wo_gate_l.append(attend_weights_wo_gate) 170 | else: 171 | out = aggregator(F.expand_dims(upper_layer_base_data, axis=0), 172 | F.expand_dims(lower_layer_data, axis=0), 173 | end_points, 174 | indptr, 175 | edge_data) 176 | out = F.reshape(out, shape=(0, 0), reverse=True) 177 | if self._every_layer_l2_normalization: 178 | out = self._l2_normalization_layer(out) 179 | else: 180 | out = dropout_layer(out) 181 | if self._dense_connect: 182 | # We are actually implementing skip connection 183 | out = F.concat(lower_layer_data, out, dim=-1) 184 | layer_in_l.append(out) 185 | out = self.out_layer(layer_in_l[-1]) 186 | if self._l2_normalization: 187 | out = self._l2_normalization_layer(out) 188 | if self._output_inner_result: 189 | return out, gate_l, sharpness_l, attend_weights_wo_gate_l 190 | else: 191 | return out 192 | 193 | 194 | class HeterGraphMultiLayerHierarchicalNodes(BaseGraphMultiLayer): 195 | """Generate the outputs for part of the nodes. 196 | 197 | The other nodes are appended to the list level by level (Level Set) 198 | 199 | """ 200 | 201 | def hybrid_forward(self, F, data, mask, end_points_l, indptr_l, indices_in_merged_l=None, edge_data_l=None): 202 | """ 203 | We aggregate the lower layer information based on the end_points and indptr information 204 | 205 | Note that here the end_points, indptr and indices are ordered from the highest layer to the lowest layer 206 | 207 | Remember that the data should be from the 0th layer! 208 | 209 | Parameters 210 | ---------- 211 | F 212 | data : Symbol or NDArray, Shape (num_node, fea_dim) 213 | The basic features (Features in Layer0) 214 | end_points_l: list of Symbol or NDArray 215 | Stores the edge connectivity within a layer 216 | Shape: [(NNZ_1,), (NNZ_2,), ...] 217 | indptr_l: list 218 | Pointers into the corresponding end_points 219 | Shape: [(N_1,), (N_2,), ...] 220 | indices_in_merged_l: list 221 | indices_in_merged_l[i] stores the positions (w.r.t. layer i) of the nodes that appear in layer i+1. 222 | To be more specific, indices_in_merged_l[0][j] stores the index of the jth node of Layer 1 w.r.t. Layer 0 (Basic Features Layer) 223 | We can select the features from the previous layer by calling `features_i[indices_in_merged_l[i], :]` 224 | Shape: [(N_1,), (N_2,), ...]
        mask : Symbol or NDArray, shape (num_node, num_set, 1)
            The node-set mask (for the nodes in Layer 0)
        edge_data_l : list
            Edge features corresponding to the end_points
            Shape: [(NNZ_1, fdim1), (NNZ_2, fdim2), ...]

        Returns
        -------
        ret : Symbol or NDArray
        """
        assert self._graph_type == "heter"
        layer_num = len(self.aggregators)
        if indices_in_merged_l is not None:
            assert len(end_points_l) == layer_num
            assert len(indptr_l) == layer_num
            assert len(indices_in_merged_l) == layer_num

        # data = self.embed(data)
        layer_in_l = [data]
        neighbor_mask_l = [mask]
        for i, (aggregator, dropout_layer) in enumerate(zip(self.aggregators, self.dropout_layers)):
            assert isinstance(aggregator, HeterGraphPoolAggregator)

            lower_layer_data = layer_in_l[-1]
            neighbor_mask = neighbor_mask_l[-1]
            if indices_in_merged_l is None:
                end_points = end_points_l
                indptr = indptr_l
                upper_layer_base_data = lower_layer_data
                node_mask = neighbor_mask
            else:
                end_points = end_points_l[i]
                indptr = indptr_l[i]
                upper_layer_base_data = F.take(lower_layer_data,
                                               indices=indices_in_merged_l[i], axis=0)
                node_mask = F.take(neighbor_mask, indices=indices_in_merged_l[i], axis=0)

            # expand_dims on None would fail, so only add the batch axis when edge
            # features are given.
            edge_data = None if edge_data_l is None else F.expand_dims(edge_data_l[i], axis=0)

            out = aggregator(F.expand_dims(upper_layer_base_data, axis=0),
                             F.expand_dims(lower_layer_data, axis=0),
                             F.expand_dims(node_mask, axis=0),
                             F.expand_dims(neighbor_mask, axis=0),
                             end_points,
                             indptr,
                             edge_data)
            # Drop the dummy batch axis: (1, N, C) -> (N, C)
            out = F.reshape(out, shape=(0, 0), reverse=True)
            if self._every_layer_l2_normalization:
                out = self._l2_normalization_layer(out)
            else:
                out = dropout_layer(out)
            if self._dense_connect:
                # Skip (dense) connection. Concatenate the upper-layer nodes' input
                # features so that the row counts match after subsampling.
                out = F.concat(upper_layer_base_data, out, dim=-1)
            layer_in_l.append(out)
            neighbor_mask_l.append(node_mask)
        out = self.out_layer(layer_in_l[-1])
        if self._l2_normalization:
            out = self._l2_normalization_layer(out)

        return out


class BaseHeterGraphMultiLayer(Block):
    def __init__(self, out_units, aggregator_args_list, dropout_rate_list,
                 dense_connect=False, every_layer_l2_normalization=False, l2_normalization=False,
                 output_inner_result=False,
                 graph_type="bi", num_node_set=None, num_edge_set=None,
                 prefix=None, params=None):
        super(BaseHeterGraphMultiLayer, self).__init__(prefix=prefix, params=params)
        self._aggregator_args_list = aggregator_args_list
        self._dropout_rate_list = dropout_rate_list
        self._dense_connect = dense_connect
        self._every_layer_l2_normalization = every_layer_l2_normalization
        self._l2_normalization = l2_normalization
        self._num_node_set = num_node_set
        self._num_edge_set = num_edge_set
        self._graph_type = graph_type
        # The normalization layer is also needed when only per-layer normalization
        # is enabled, so construct it in either case.
        if self._l2_normalization or self._every_layer_l2_normalization:
            self._l2_normalization_layer = L2Normalization(axis=-1)
        self._output_inner_result = output_inner_result
        with self.name_scope():
            self.aggregators = nn.Sequential()
            self.dropout_layers = nn.Sequential()
            for args, dropout_rate in zip(aggregator_args_list, dropout_rate_list):
                self.aggregators.add(parse_aggregator_from_desc(args))
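                # Pair every aggregator (built from its config description) with its
                # own dropout layer, one rate per layer.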
                self.dropout_layers.add(nn.Dropout(dropout_rate))
            self.out_layer = nn.Dense(units=out_units, flatten=False, prefix="out_")


class BiGraphMultiLayerHierarchicalNodes(BaseHeterGraphMultiLayer):
    """Generate the outputs for part of the nodes.

    The other nodes are appended to the list level by level (Level Set).
    """

    def forward(self, data, end_points_l, indptr_l, indices_in_merged_l=None,
                node_type_mask=None, edge_type_mask_l=None, seg_indices_l=None):
        """
        Aggregate the lower-layer information based on the end_points and indptr information.

        Note that end_points_l, indptr_l and indices_in_merged_l should be ordered from the
        highest layer to the lowest layer, while `data` holds the features of Layer 0.

        Parameters
        ----------
        data : Symbol or NDArray, shape (num_node, fea_dim)
            The basic features (features in Layer 0)
        end_points_l : list or Symbol or NDArray
            Stores the edge connectivity within each layer
            Shape: [(NNZ_1,), (NNZ_2,), ...]
        indptr_l : list
            Pointers into the corresponding end_points
            Shape: [(N_1,), (N_2,), ...]
        indices_in_merged_l : list
            indices_in_merged_l[i][j] stores the position, within layer i, of the jth node
            of layer i+1 (same semantics as in GraphMultiLayerHierarchicalNodes).
            We can select the features from the previous layer by calling
            `features_i[indices_in_merged_l[i], :]`
            Shape: [(N_1,), (N_2,), ...]
        node_type_mask : Symbol or NDArray, shape (num_node, num_set, 1)
            The node-set mask (for the nodes in Layer 0); required when num_node_set is given
        edge_type_mask_l : list
            Edge-set masks aligned with the end_points of each layer; required when
            num_edge_set is given. One entry per layer, with NNZ_i rows in layer i.
        seg_indices_l : list
            The edge indices of the end_points, i.e. arange(end_points_l[i].size) for layer i
            Shape: [(NNZ_1,), (NNZ_2,), ...]
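            (seg_indices_l[i] can be built with np.arange(end_points_l[i].size); it simply
            enumerates the edges, presumably so the segment ops can index them.)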
        Returns
        -------
        ret : Symbol or NDArray
        """
        assert self._graph_type == "bi"
        layer_num = len(self.aggregators)
        if indices_in_merged_l is not None:
            assert len(end_points_l) == layer_num
            assert len(indptr_l) == layer_num
            assert len(indices_in_merged_l) == layer_num

        if self._num_edge_set is not None:
            assert len(edge_type_mask_l) == layer_num
            assert len(seg_indices_l) == layer_num
        else:
            edge_type_mask_l = [None for _ in range(layer_num)]
            seg_indices_l = [None for _ in range(layer_num)]

        if self._num_node_set is not None:
            assert node_type_mask is not None
            neighbor_node_type_mask_l = [node_type_mask]
            for i in range(layer_num):
                if indices_in_merged_l is None:
                    neighbor_node_type_mask_l.append(node_type_mask)
                else:
                    neighbor_node_type_mask_l.append(nd.take(neighbor_node_type_mask_l[-1],
                                                             indices=indices_in_merged_l[i], axis=0))
        else:
            neighbor_node_type_mask_l = [None for _ in range(layer_num + 1)]

        # data = self.embed(data)
        layer_in_l = [data]
        for i, (aggregator, dropout_layer) in enumerate(zip(self.aggregators, self.dropout_layers)):
            assert isinstance(aggregator, BiGraphPoolAggregator)
            lower_layer_data = layer_in_l[-1]
            if indices_in_merged_l is None:
                end_points = end_points_l
                indptr = indptr_l
                upper_layer_base_data = lower_layer_data
            else:
                end_points = end_points_l[i]
                indptr = indptr_l[i]
                upper_layer_base_data = nd.take(lower_layer_data,
                                                indices=indices_in_merged_l[i], axis=0)
            lower_layer_data = dropout_layer(lower_layer_data)

            # expand_dims on None would fail, so only add the batch axis when the
            # corresponding masks exist.
            node_mask = neighbor_node_type_mask_l[i + 1]
            neighbor_mask = neighbor_node_type_mask_l[i]
            edge_mask = edge_type_mask_l[i]
            out = aggregator(nd.expand_dims(upper_layer_base_data, axis=0),
                             nd.expand_dims(lower_layer_data, axis=0),
                             end_points,
                             indptr,
                             None if node_mask is None else nd.expand_dims(node_mask, axis=0),
                             None if neighbor_mask is None else nd.expand_dims(neighbor_mask, axis=0),
                             None if edge_mask is None else nd.expand_dims(edge_mask, axis=0),
                             seg_indices_l[i])
            # Drop the dummy batch axis: (1, N, C) -> (N, C)
            out = nd.reshape(out, shape=(0, 0), reverse=True)
            if self._every_layer_l2_normalization:
                out = self._l2_normalization_layer(out)
            # (No extra dropout here: it has already been applied to lower_layer_data above.)
            if self._dense_connect:
                # Skip (dense) connection. Concatenate the upper-layer nodes' input
                # features so that the row counts match after subsampling.
                out = nd.concat(upper_layer_base_data, out, dim=-1)
            layer_in_l.append(out)
        out = self.out_layer(layer_in_l[-1])
        if self._l2_normalization:
            out = self._l2_normalization_layer(out)
        return out


class HeterEmbedLayer(Block):
    def __init__(self, embed_units, num_set, prefix=None, params=None):
        super(HeterEmbedLayer, self).__init__(prefix=prefix, params=params)
        self._num_set = num_set
        self._embed_units = embed_units
        with self.name_scope():
            self.embed_layer = nn.Sequential()
            for i in range(self._num_set):
                self.embed_layer.add(nn.Dense(units=self._embed_units, flatten=False, use_bias=False))

    def forward(self, sampled_node_order, layer0_features_nd_l):
        # Embed each node set with its own Dense layer, stack the embeddings, and
        # reorder the rows to match the sampled node order.
        fea_embeds_l = []
        for i in range(len(self.embed_layer)):
            fea_embeds_l.append(self.embed_layer[i](layer0_features_nd_l[i]))
        fea_embeds = mx.nd.concat(*fea_embeds_l, dim=0)
        data = mx.nd.take(fea_embeds, indices=sampled_node_order)
        return data


class HeterPredLayer(HybridBlock):
    def __init__(self, input_node_dim, hidden_dim, out_units, num_pred_set, act="leaky", prefix=None, params=None):
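        """Prediction head for heterogeneous outputs (docstring added; behaviour inferred
        from hybrid_forward below): the embeddings of the num_pred_set node sets belonging
        to one example are concatenated and fed through a hidden layer, followed by an
        output layer producing out_units scores.
        """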
        super(HeterPredLayer, self).__init__(prefix=prefix, params=params)
        self._input_node_dim = input_node_dim
        self._hidden_dim = hidden_dim
        self._out_units = out_units
        self._num_pred_set = num_pred_set
        with self.name_scope():
            self._act = get_activation(act)
            self.hidden_layer = nn.Dense(units=self._hidden_dim, flatten=False)
            self.out_layer = nn.Dense(units=self._out_units, flatten=False)

    def hybrid_forward(self, F, data):
        # (num_pred_set * num_example, dim) -> (num_example, num_pred_set * dim)
        data = F.reshape(F.swapaxes(F.reshape(data,
                                              shape=(self._num_pred_set, -1, self._input_node_dim)),
                                    0, 1),
                         shape=(-1, self._input_node_dim * self._num_pred_set))
        hidden_1 = self._act(self.hidden_layer(data))
        out = self.out_layer(hidden_1)
        return out


class HeterDecoder(HybridBlock):
    def __init__(self, input_node_dim, num_pred_set, out_units, act="leaky", prefix=None, params=None):
        super(HeterDecoder, self).__init__(prefix=prefix, params=params)
        self._input_node_dim = input_node_dim
        self._num_pred_set = num_pred_set
        self._out_units = out_units
        with self.name_scope():
            self._act = get_activation(act)
            self.bilinear_W = nn.HybridSequential()
            for i in range(self._out_units):
                self.bilinear_W.add(nn.Dense(units=self._input_node_dim, flatten=False, use_bias=False))

    def hybrid_forward(self, F, data, rates):
        data = F.reshape(data, shape=(self._num_pred_set, -1, self._input_node_dim))
        user_embeds = F.reshape(F.slice_axis(data, axis=0, begin=0, end=1),
                                shape=(-1, self._input_node_dim))
        item_embeds = F.reshape(F.slice_axis(data, axis=0, begin=1, end=2),
                                shape=(-1, self._input_node_dim))

        rate_weight_l = []
        for i in range(self._out_units):
            # Bilinear score u^T W_i v, computed as sum((W_i u) * v); shape (user_num, 1)
            weight = F.sum(F.broadcast_mul(self.bilinear_W[i](user_embeds),
                                           item_embeds), axis=1, keepdims=True)
            rate_weight_l.append(weight)
        # Softmax over the rating levels, then take the expectation w.r.t. the rate values
        rate_weights = F.softmax(F.concat(*rate_weight_l, dim=1))
        out = F.sum(F.broadcast_mul(rate_weights, rates), axis=1)
        return out


class BiDecoder(Block):
    def __init__(self, input_node_dim, num_pred_set, out_units, act="leaky", prefix=None, params=None):
        super(BiDecoder, self).__init__(prefix=prefix, params=params)
        self._input_node_dim = input_node_dim
        self._num_pred_set = num_pred_set
        self._out_units = out_units
        with self.name_scope():
            self.bilinear_W = nn.Sequential()
            for i in range(self._out_units):
                self.bilinear_W.add(nn.Dense(units=self._input_node_dim, flatten=False, use_bias=False))

    def forward(self, data):
        data = nd.reshape(data, shape=(self._num_pred_set, -1, self._input_node_dim))
        user_embeds = nd.reshape(nd.slice_axis(data, axis=0, begin=0, end=1),  # (user_num, input_node_dim)
                                 shape=(-1, self._input_node_dim))
        item_embeds = nd.reshape(nd.slice_axis(data, axis=0, begin=1, end=2),  # (item_num, input_node_dim)
                                 shape=(-1, self._input_node_dim))

        rate_weight_l = []
        for i in range(self._out_units):
            # Bilinear score u^T W_i v for every user-item pair; shape (user_num, 1)
            weight = nd.sum(nd.broadcast_mul(self.bilinear_W[i](user_embeds),
                                             item_embeds), axis=1, keepdims=True)
            rate_weight_l.append(weight)
        out = nd.concat(*rate_weight_l, dim=1)
        return out
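

if __name__ == '__main__':
    # Minimal smoke test for BiDecoder (hypothetical shapes, not part of the original
    # file): 4 users and 4 items with 8-dimensional embeddings, 5 rating levels.
    import mxnet as mx
    dec = BiDecoder(input_node_dim=8, num_pred_set=2, out_units=5)
    dec.initialize()
    embeddings = mx.nd.random.normal(shape=(2 * 4, 8))  # users stacked on top of items
    print(dec(embeddings).shape)  # expected: (4, 5) per-pair rating scores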
--------------------------------------------------------------------------------