├── .clang-format ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── gnnflow ├── __init__.py ├── cache │ ├── __init__.py │ ├── cache.py │ ├── fifo_cache.py │ ├── gnnlab_static_cache.py │ ├── lfu_cache.py │ └── lru_cache.py ├── config.py ├── csrc │ ├── api.cc │ ├── common.h │ ├── doubly_linked_list.cu │ ├── doubly_linked_list.h │ ├── dynamic_graph.cu │ ├── dynamic_graph.h │ ├── kvstore.cc │ ├── kvstore.h │ ├── logging.cc │ ├── logging.h │ ├── resource_holder.h │ ├── sampling_kernels.cu │ ├── sampling_kernels.h │ ├── temporal_block_allocator.cu │ ├── temporal_block_allocator.h │ ├── temporal_sampler.cu │ ├── temporal_sampler.h │ ├── utils.cu │ └── utils.h ├── data.py ├── distributed │ ├── __init__.py │ ├── common.py │ ├── dispatcher.py │ ├── dist_context.py │ ├── dist_graph.py │ ├── dist_sampler.py │ ├── graph_services.py │ ├── kvstore.py │ ├── partition.py │ └── utils.py ├── dynamic_graph.py ├── models │ ├── __init__.py │ ├── apan.py │ ├── dgnn.py │ ├── gat.py │ ├── graphsage.py │ ├── jodie.py │ └── modules │ │ ├── __init__.py │ │ ├── apan_memory.py │ │ ├── layers.py │ │ ├── memory.py │ │ └── memory_updater.py ├── temporal_sampler.py └── utils.py ├── requirements.txt ├── scripts ├── download_data.sh ├── offline_edge_prediction.py ├── offline_edge_prediction_pipethread.py ├── offline_edge_prediction_presample.py ├── pipeline.py ├── pipeline_distributed.py ├── run_offline.sh ├── run_offline_dist.sh └── train.py ├── setup.py ├── tests ├── test_build_graph.py ├── test_dataset.py ├── test_dynamic_graph.py ├── test_model.py ├── test_temporal_sampler.py └── utils.py └── tgl ├── CMakeLists.txt ├── offline_tgl_presample.py ├── run_tgl.sh ├── sampler_core.cpp ├── setup_tgl.py └── utils.py /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: Google 3 | --- 4 | Language: Cpp 5 | ColumnLimit: 80 6 | 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | 133 | # Prerequisites 134 | *.d 135 | 136 | # Compiled Object files 137 | *.slo 138 | *.lo 139 | *.o 140 | *.obj 141 | 142 | # Precompiled Headers 143 | *.gch 144 | *.pch 145 | 146 | # Compiled Dynamic libraries 147 | *.so 148 | *.dylib 149 | *.dll 150 | 151 | # Fortran module files 152 | *.mod 153 | *.smod 154 | 155 | # Compiled Static libraries 156 | *.lai 157 | *.la 158 | *.a 159 | *.lib 160 | 161 | # Executables 162 | *.exe 163 | *.out 164 | *.app 165 | 166 | # CMake 167 | CMakeLists.txt.user 168 | CMakeCache.txt 169 | CMakeFiles 170 | CMakeScripts 171 | Testing 172 | Makefile 173 | cmake_install.cmake 174 | install_manifest.txt 175 | compile_commands.json 176 | CTestTestfile.cmake 177 | _deps 178 | 179 | # IDE 180 | 181 | .idea 182 | .vscode 183 | data/ 184 | data 185 | .clangd/ 186 | compile_commands.json 187 | 188 | rmm_log.txt 189 | *.qdrep 190 | *.ncu-rep 191 | *.pt 192 | col*.txt 193 | row*.txt 194 | partition_table.txt 195 | *.bin 196 | *.png 197 | 198 | # MacOS 199 | .DS_Store 200 | 201 | *.npy -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/pybind11"] 2 | path = third_party/pybind11 3 | url = https://github.com/pybind/pybind11.git 4 | [submodule "third_party/rmm"] 5 | path = third_party/rmm 6 | url = https://github.com/jasperzhong/rmm.git 7 | [submodule "third_party/spdlog"] 8 | path = third_party/spdlog 9 | url = https://github.com/gabime/spdlog.git 10 | [submodule "third_party/abseil-cpp"] 11 | path = third_party/abseil-cpp 12 | url = https://github.com/abseil/abseil-cpp 13 | 
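Before the build files, a quick orientation: the directory tree at the top of this dump is effectively a map of the `gnnflow` Python package. The sketch below only wires together module paths and names that appear verbatim later in this dump (`get_default_config`, the cache classes, `TemporalSampler`, `KVStoreClient`); the `"TGN"`/`"REDDIT"` pair is just one of the supported model/dataset combinations.

```python
# Orientation sketch: how the main gnnflow modules from the tree fit together.
# All imports below mirror ones used later in this dump; the printed values
# are only illustrative.
from gnnflow.config import get_default_config
from gnnflow.cache import FIFOCache, GNNLabStaticCache, LFUCache, LRUCache
from gnnflow.temporal_sampler import TemporalSampler
from gnnflow.distributed.kvstore import KVStoreClient

model_config, data_config = get_default_config("TGN", "REDDIT")
print(model_config["fanouts"], data_config["mem_resource_type"])
```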
-------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18) 2 | project(libgnnflow CXX) 3 | enable_language(CUDA) 4 | 5 | set(TARGET_LIB "libgnnflow") 6 | 7 | set(CMAKE_CXX_STANDARD 14) 8 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 9 | set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) 10 | 11 | # set cuda arch 12 | set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) 13 | include(FindCUDA/select_compute_arch) 14 | CUDA_DETECT_INSTALLED_GPUS(INSTALLED_GPU_CCS_1) 15 | string(STRIP "${INSTALLED_GPU_CCS_1}" INSTALLED_GPU_CCS_2) 16 | string(REPLACE " " ";" INSTALLED_GPU_CCS_3 "${INSTALLED_GPU_CCS_2}") 17 | string(REPLACE "." "" CUDA_ARCH_LIST "${INSTALLED_GPU_CCS_3}") 18 | set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH_LIST}) 19 | message(STATUS "CUDA_ARCH_LIST: ${CUDA_ARCH_LIST}") 20 | 21 | # 3rd party 22 | find_package(PythonLibs REQUIRED) 23 | add_subdirectory(third_party/pybind11) 24 | 25 | 26 | include_directories(${PYTHON_INCLUDE_DIRS}) 27 | include_directories(${PROJECT_SOURCE_DIR}/gnnflow/csrc) 28 | include_directories(/usr/local/cuda/include) 29 | include_directories(third_party/pybind11/include) 30 | include_directories(third_party/spdlog/include) 31 | include_directories(third_party/rmm/include) 32 | 33 | 34 | file(GLOB_RECURSE GNNFLOW_SRC_FILES ${PROJECT_SOURCE_DIR}/gnnflow/csrc/*.cc) 35 | set_source_files_properties(${PROJECT_SOURCE_DIR}/gnnflow/csrc/api.cc PROPERTIES LANGUAGE CUDA) 36 | file(GLOB_RECURSE GNNFLOW_SRC_CUDA_FILES ${PROJECT_SOURCE_DIR}/gnnflow/csrc/*.cu) 37 | set_source_files_properties(${GNNFLOW_SRC_CUDA_FILES} PROPERTIES LANGUAGE CUDA) 38 | list(APPEND GNNFLOW_SRC_FILES ${GNNFLOW_SRC_CUDA_FILES}) 39 | 40 | pybind11_add_module(${TARGET_LIB} ${GNNFLOW_SRC_FILES}) 41 | 42 | add_subdirectory(third_party/abseil-cpp) 43 | target_link_libraries(${TARGET_LIB} PRIVATE absl::flat_hash_map) 44 | 45 | find_package(Torch REQUIRED) 46 | find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") 47 | target_link_libraries(${TARGET_LIB} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) 48 | 49 | target_compile_options(${TARGET_LIB} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>: 50 | --generate-line-info 51 | --use_fast_math 52 | -rdc=true 53 | -fopenmp 54 | >) 55 | 56 | set_property(TARGET ${TARGET_LIB} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 57 | 58 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fopenmp -fPIC -Wall -ftree-vectorize") 59 | set(ARCH_FLAGS "-march=native -mtune=native") 60 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_FLAGS}") 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSPipe 2 | 3 | This repository is the official implementation of MSPipe: Efficient Temporal GNN Training via Staleness-aware Pipeline. 4 | 5 | ## Install 6 | 7 | Our development environment: 8 | 9 | - Ubuntu 20.04 LTS 10 | - g++ 9.4 11 | - CUDA 11.3 / 11.6 12 | - cmake 3.23 13 | 14 | Dependencies: 15 | 16 | - torch >= 1.10 17 | - dgl (CUDA version) 18 | 19 | Compile and install MSPipe: 20 | 21 | ```sh 22 | git submodule update --init --recursive 23 | pip install -r requirements.txt 24 | python setup.py install 25 | ``` 26 | 27 | For debug mode: 28 | 29 | ```sh 30 | DEBUG=1 pip install -v -e .
31 | ``` 32 | 33 | Compile and install TGL (presample version): 34 | 35 | ```sh 36 | cd tgl 37 | python setup_tgl.py build_ext --inplace 38 | ``` 39 | 40 | ## Prepare data 41 | 42 | ```sh 43 | cd scripts/ && ./download_data.sh 44 | ``` 45 | 46 | ## Train 47 | 48 | **MSPipe** 49 | 50 | Train the [TGN](https://arxiv.org/pdf/2006.10637v2.pdf) model on the REDDIT dataset with MSPipe on 4 GPUs: 51 | 52 | ```sh 53 | cd scripts 54 | ./run_offline.sh TGN REDDIT 4 55 | ``` 56 | 57 | **Presample (TGL)** 58 | 59 | Train the [TGN](https://arxiv.org/pdf/2006.10637v2.pdf) model on the REDDIT dataset with Presample on 4 GPUs: 60 | 61 | ```sh 62 | cd tgl 63 | ./run_tgl.sh TGN REDDIT 4 64 | ``` 65 | 66 | 67 | 68 | **Distributed training** 69 | 70 | To train the TGN model on the GDELT dataset across more than one server, complete the following steps on each server: 71 | 72 | 1. Change `INTERFACE` to the name of your network interface (it can be found with `ifconfig`). 73 | 2. Change the following variables: 74 | - `HOST_NODE_ADDR`: IP address of the host machine 75 | - `HOST_NODE_PORT`: The port of the host machine 76 | - `NNODES`: Total number of servers 77 | - `NPROC_PER_NODE`: The number of GPUs on each server 78 | 79 | ```sh 80 | cd scripts 81 | ./run_offline_dist.sh TGN GDELT 82 | ``` 83 | -------------------------------------------------------------------------------- /gnnflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynamic_graph import * 2 | from .temporal_sampler import * 3 | 4 | 5 | -------------------------------------------------------------------------------- /gnnflow/cache/__init__.py: -------------------------------------------------------------------------------- 1 | from .lru_cache import LRUCache 2 | from .lfu_cache import LFUCache 3 | from .fifo_cache import FIFOCache 4 | from .gnnlab_static_cache import GNNLabStaticCache 5 | -------------------------------------------------------------------------------- /gnnflow/cache/fifo_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | 5 | from gnnflow.cache.cache import Cache 6 | from gnnflow.distributed.kvstore import KVStoreClient 7 | 8 | 9 | class FIFOCache(Cache): 10 | """ 11 | First-in-first-out cache 12 | """ 13 | 14 | def __init__(self, edge_cache_ratio: int, node_cache_ratio: int, 15 | num_nodes: int, num_edges: int, 16 | device: Union[str, torch.device], 17 | node_feats: Optional[torch.Tensor] = None, 18 | edge_feats: Optional[torch.Tensor] = None, 19 | dim_node_feat: Optional[int] = 0, 20 | dim_edge_feat: Optional[int] = 0, 21 | pinned_nfeat_buffs: Optional[torch.Tensor] = None, 22 | pinned_efeat_buffs: Optional[torch.Tensor] = None, 23 | kvstore_client: Optional[KVStoreClient] = None, 24 | distributed: Optional[bool] = False, 25 | neg_sample_ratio: Optional[int] = 1): 26 | """ 27 | Initialize the cache 28 | 29 | Args: 30 | edge_cache_ratio: The edge ratio of the cache size to the total number of nodes or edges 31 | range: [0, 1]. 32 | node_cache_ratio: The node ratio of the cache size to the total number of nodes or edges 33 | range: [0, 1].
34 | num_nodes: The number of nodes in the graph 35 | num_edges: The number of edges in the graph 36 | device: The device to use 37 | node_feats: The node features 38 | edge_feats: The edge features 39 | dim_node_feat: The dimension of node features 40 | dim_edge_feat: The dimension of edge features 41 | pinned_nfeat_buffs: The pinned memory buffers for node features 42 | pinned_efeat_buffs: The pinned memory buffers for edge features 43 | kvstore_client: The KVStore_Client for fetching features when using distributed 44 | training 45 | distributed: Whether to use distributed training 46 | neg_sample_ratio: The ratio of negative samples to positive samples 47 | """ 48 | super(FIFOCache, self).__init__(edge_cache_ratio, node_cache_ratio, num_nodes, num_edges, device, 49 | node_feats, edge_feats, dim_node_feat, dim_edge_feat, 50 | pinned_nfeat_buffs, pinned_efeat_buffs, 51 | kvstore_client, distributed, neg_sample_ratio) 52 | self.name = 'fifo' 53 | # pointer to the last entry for the recent cached nodes 54 | self.cache_node_pointer = 0 55 | self.cache_edge_pointer = 0 56 | 57 | def init_cache(self, *args, **kwargs): 58 | """ 59 | Init the cache with features 60 | """ 61 | if self.distributed: 62 | return 63 | super(FIFOCache, self).init_cache(*args, **kwargs) 64 | if self.node_feats is not None: 65 | self.cache_node_pointer = self.node_capacity - 1 66 | 67 | if self.edge_feats is not None: 68 | self.cache_edge_pointer = self.edge_capacity - 1 69 | 70 | def reset(self): 71 | """ 72 | Reset the cache 73 | """ 74 | if self.dim_edge_feat != 0: 75 | self.cache_edge_pointer = self.edge_capacity - 1 76 | 77 | def update_node_cache(self, cached_node_index: torch.Tensor, 78 | uncached_node_id: torch.Tensor, 79 | uncached_node_feature: torch.Tensor): 80 | """ 81 | Update the node cache 82 | 83 | Args: 84 | cached_node_index: The index of the cached nodes 85 | uncached_node_id: The id of the uncached nodes 86 | uncached_node_feature: The features of the uncached nodes 87 | """ 88 | # If the number of nodes to cache is larger than the cache capacity, 89 | # we only cache the first self.capacity nodes 90 | if len(uncached_node_id) > self.node_capacity: 91 | num_node_to_cache = self.node_capacity 92 | else: 93 | num_node_to_cache = len(uncached_node_id) 94 | node_id_to_cache = uncached_node_id[:num_node_to_cache] 95 | node_feature_to_cache = uncached_node_feature[:num_node_to_cache] 96 | 97 | if self.cache_node_pointer + num_node_to_cache < self.node_capacity: 98 | removing_cache_index = torch.arange( 99 | self.cache_node_pointer + 1, self.cache_node_pointer + num_node_to_cache + 1) 100 | self.cache_node_pointer = self.cache_node_pointer + num_node_to_cache 101 | else: 102 | removing_cache_index = torch.cat([torch.arange(num_node_to_cache - (self.node_capacity - 1 - self.cache_node_pointer)), 103 | torch.arange(self.cache_node_pointer + 1, self.node_capacity)]) 104 | self.cache_node_pointer = num_node_to_cache - \ 105 | (self.node_capacity - 1 - self.cache_node_pointer) - 1 106 | assert len(removing_cache_index) == len( 107 | node_id_to_cache) == len(node_feature_to_cache) 108 | removing_cache_index = removing_cache_index.to( 109 | device=self.device, non_blocking=True) 110 | removing_node_id = self.cache_index_to_node_id[removing_cache_index] 111 | 112 | # update cache attributes 113 | self.cache_node_buffer[removing_cache_index] = node_feature_to_cache 114 | self.cache_node_flag[removing_node_id] = False 115 | self.cache_node_flag[node_id_to_cache] = True 116 | self.cache_node_map[removing_node_id] = 
-1 117 | self.cache_node_map[node_id_to_cache] = removing_cache_index 118 | self.cache_index_to_node_id[removing_cache_index] = node_id_to_cache 119 | 120 | def update_edge_cache(self, cached_edge_index: torch.Tensor, 121 | uncached_edge_id: torch.Tensor, 122 | uncached_edge_feature: torch.Tensor): 123 | """ 124 | Update the edge cache 125 | 126 | Args: 127 | cached_edge_index: The index of the cached edges 128 | uncached_edge_id: The id of the uncached edges 129 | uncached_edge_feature: The features of the uncached edges 130 | """ 131 | # If the number of edges to cache is larger than the cache capacity, 132 | # we only cache the first self.capacity edges 133 | if len(uncached_edge_id) > self.edge_capacity: 134 | num_edge_to_cache = self.edge_capacity 135 | else: 136 | num_edge_to_cache = len(uncached_edge_id) 137 | edge_id_to_cache = uncached_edge_id[:num_edge_to_cache] 138 | edge_feature_to_cache = uncached_edge_feature[:num_edge_to_cache] 139 | 140 | if self.cache_edge_pointer + num_edge_to_cache < self.edge_capacity: 141 | removing_cache_index = torch.arange( 142 | self.cache_edge_pointer + 1, self.cache_edge_pointer + num_edge_to_cache + 1) 143 | self.cache_edge_pointer = self.cache_edge_pointer + num_edge_to_cache 144 | else: 145 | removing_cache_index = torch.cat([torch.arange(num_edge_to_cache - (self.edge_capacity - 1 - self.cache_edge_pointer)), 146 | torch.arange(self.cache_edge_pointer + 1, self.edge_capacity)]) 147 | self.cache_edge_pointer = num_edge_to_cache - \ 148 | (self.edge_capacity - 1 - self.cache_edge_pointer) - 1 149 | assert len(removing_cache_index) == len( 150 | edge_id_to_cache) == len(edge_feature_to_cache) 151 | removing_cache_index = removing_cache_index.to( 152 | device=self.device, non_blocking=True) 153 | removing_edge_id = self.cache_index_to_edge_id[removing_cache_index] 154 | 155 | # update cache attributes 156 | self.cache_edge_buffer[removing_cache_index] = edge_feature_to_cache 157 | self.cache_edge_flag[removing_edge_id] = False 158 | self.cache_edge_flag[edge_id_to_cache] = True 159 | self.cache_edge_map[removing_edge_id] = -1 160 | self.cache_edge_map[edge_id_to_cache] = removing_cache_index 161 | self.cache_index_to_edge_id[removing_cache_index] = edge_id_to_cache 162 | -------------------------------------------------------------------------------- /gnnflow/cache/gnnlab_static_cache.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | from dgl.heterograph import DGLBlock 6 | 7 | from gnnflow.cache.cache import Cache 8 | from gnnflow.distributed.kvstore import KVStoreClient 9 | from gnnflow.temporal_sampler import TemporalSampler 10 | from gnnflow.utils import get_batch_no_neg 11 | 12 | 13 | class GNNLabStaticCache(Cache): 14 | """ 15 | GNNLab static cache 16 | 17 | paper: https://dl.acm.org/doi/abs/10.1145/3492321.3519557 18 | """ 19 | 20 | def __init__(self, cache_ratio: int, num_nodes: int, num_edges: int, 21 | device: Union[str, torch.device], 22 | node_feats: Optional[torch.Tensor] = None, 23 | edge_feats: Optional[torch.Tensor] = None, 24 | dim_node_feat: Optional[int] = 0, 25 | dim_edge_feat: Optional[int] = 0, 26 | pinned_nfeat_buffs: Optional[torch.Tensor] = None, 27 | pinned_efeat_buffs: Optional[torch.Tensor] = None, 28 | kvstore_client: Optional[KVStoreClient] = None, 29 | distributed: Optional[bool] = False, 30 | neg_sample_ratio: Optional[int] = 1): 31 | """ 32 | Initialize the cache 33 | 34 | Args: 35 | 
cache_ratio: The ratio of the cache size to the total number of nodes or edges 36 | range: [0, 1]. 37 | num_nodes: The number of nodes in the graph 38 | num_edges: The number of edges in the graph 39 | device: The device to use 40 | node_feats: The node features 41 | edge_feats: The edge features 42 | dim_node_feat: The dimension of node features 43 | dim_edge_feat: The dimension of edge features 44 | pinned_nfeat_buffs: The pinned memory buffers for node features 45 | pinned_efeat_buffs: The pinned memory buffers for edge features 46 | kvstore_client: The KVStore_Client for fetching features when using distributed 47 | training 48 | distributed: Whether to use distributed training 49 | neg_sample_ratio: The ratio of negative samples to positive samples 50 | """ 51 | super(GNNLabStaticCache, self).__init__(cache_ratio, num_nodes, 52 | num_edges, device, 53 | node_feats, edge_feats, 54 | dim_node_feat, dim_edge_feat, 55 | pinned_nfeat_buffs, 56 | pinned_efeat_buffs, 57 | kvstore_client, distributed, 58 | neg_sample_ratio) 59 | # name 60 | self.name = 'gnnlab' 61 | 62 | self.cache_index_to_node_id = None 63 | self.cache_index_to_edge_id = None 64 | 65 | def reset(self): 66 | """Reset the cache""" 67 | # do nothing 68 | return 69 | 70 | def get_mem_size(self) -> int: 71 | """ 72 | Get the memory size of the cache in bytes 73 | """ 74 | mem_size = 0 75 | if self.dim_node_feat != 0: 76 | mem_size += self.cache_node_buffer.element_size() * self.cache_node_buffer.nelement() 77 | mem_size += self.cache_node_flag.element_size() * self.cache_node_flag.nelement() 78 | mem_size += self.cache_node_map.element_size() * self.cache_node_map.nelement() 79 | 80 | if self.dim_edge_feat != 0: 81 | mem_size += self.cache_edge_buffer.element_size() * self.cache_edge_buffer.nelement() 82 | mem_size += self.cache_edge_flag.element_size() * self.cache_edge_flag.nelement() 83 | mem_size += self.cache_edge_map.element_size() * self.cache_edge_map.nelement() 84 | 85 | return mem_size 86 | 87 | def init_cache(self, *args, **kwargs): 88 | """ 89 | Init the caching with features 90 | """ 91 | node_sampled_count = torch.zeros(self.num_nodes, dtype=torch.int32) 92 | edge_sampled_count = torch.zeros(self.num_edges, dtype=torch.int32) 93 | eid_to_nid = torch.zeros(self.num_edges, dtype=torch.int64) 94 | 95 | sampler = kwargs['sampler'] 96 | train_df = kwargs['train_df'] 97 | pre_sampling_rounds = kwargs.get('pre_sampling_rounds', 2) 98 | batch_size = kwargs.get('batch_size', 600) 99 | 100 | # Do sampling for multiple rounds 101 | for _ in range(pre_sampling_rounds): 102 | for target_nodes, ts, _ in get_batch_no_neg(train_df, batch_size): 103 | mfgs = sampler.sample(target_nodes, ts) 104 | if self.node_feats is not None or self.dim_node_feat != 0: 105 | for b in mfgs[0]: 106 | node_sampled_count[b.srcdata['ID']] += 1 107 | if self.edge_feats is not None or self.dim_edge_feat != 0: 108 | for mfg in mfgs: 109 | for b in mfg: 110 | if b.num_src_nodes() > b.num_dst_nodes(): 111 | edge_sampled_count[b.edata['ID']] += 1 112 | eid_to_nid[b.edata['ID'] 113 | ] = b.srcdata['ID'][b.num_dst_nodes():] 114 | 115 | if self.distributed: 116 | if self.dim_node_feat != 0 and self.node_capacity > 0: 117 | # Get the top-k nodes with the highest sampling count 118 | cache_node_id = torch.topk( 119 | node_sampled_count, k=self.node_capacity, largest=True).indices.to(self.device) 120 | 121 | # Init parameters related to feature fetching 122 | cache_node_index = torch.arange( 123 | self.node_capacity, dtype=torch.int64).to(self.device) 124 | 
self.cache_node_buffer[cache_node_index] = self.kvstore_client.pull( 125 | cache_node_id.cpu(), mode='node').to(self.device) 126 | self.cache_node_flag[cache_node_id] = True 127 | self.cache_node_map[cache_node_id] = cache_node_index 128 | else: 129 | if self.node_feats is not None: 130 | # Get the top-k nodes with the highest sampling count 131 | cache_node_id = torch.topk( 132 | node_sampled_count, k=self.node_capacity, largest=True).indices.to(self.device) 133 | 134 | # Init parameters related to feature fetching 135 | cache_node_index = torch.arange( 136 | self.node_capacity, dtype=torch.int64).to(self.device) 137 | self.cache_node_buffer[cache_node_index] = self.node_feats[cache_node_id].to( 138 | self.device, non_blocking=True) 139 | self.cache_node_flag[cache_node_id] = True 140 | self.cache_node_map[cache_node_id] = cache_node_index 141 | 142 | if self.distributed: 143 | if self.dim_edge_feat != 0 and self.edge_capacity > 0: 144 | # Get the top-k edges with the highest sampling count 145 | cache_edge_id = torch.topk( 146 | edge_sampled_count, k=self.edge_capacity, largest=True).indices.to(self.device) 147 | 148 | # Init parameters related to feature fetching 149 | cache_edge_index = torch.arange( 150 | self.edge_capacity, dtype=torch.int64).to(self.device) 151 | self.cache_edge_buffer[cache_edge_index] = self.kvstore_client.pull( 152 | cache_edge_id.cpu(), mode='edge', nid=eid_to_nid[cache_edge_id.cpu()]).to(self.device) 153 | self.cache_edge_flag[cache_edge_id] = True 154 | self.cache_edge_map[cache_edge_id] = cache_edge_index 155 | 156 | else: 157 | if self.edge_feats is not None: 158 | # Get the top-k edges with the highest sampling count 159 | cache_edge_id = torch.topk( 160 | edge_sampled_count, k=self.edge_capacity, largest=True).indices.to(self.device) 161 | 162 | # Init parameters related to feature fetching 163 | cache_edge_index = torch.arange( 164 | self.edge_capacity, dtype=torch.int64).to(self.device) 165 | self.cache_edge_buffer[cache_edge_index] = self.edge_feats[cache_edge_id].to( 166 | self.device, non_blocking=True) 167 | self.cache_edge_flag[cache_edge_id] = True 168 | self.cache_edge_map[cache_edge_id] = cache_edge_index 169 | 170 | def fetch_feature(self, mfgs: List[List[DGLBlock]], 171 | eid: Optional[np.ndarray] = None, update_cache: bool = True, 172 | target_edge_features: bool = True): 173 | """Fetching the node features of input_node_ids 174 | 175 | Args: 176 | mfgs: message-passing flow graphs 177 | update_cache: whether to update the cache 178 | 179 | Returns: 180 | mfgs: message-passing flow graphs with node/edge features 181 | """ 182 | return super(GNNLabStaticCache, self).fetch_feature(mfgs, eid=eid, update_cache=False, target_edge_features=target_edge_features) 183 | -------------------------------------------------------------------------------- /gnnflow/cache/lfu_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | 5 | from gnnflow.cache.cache import Cache 6 | from gnnflow.distributed.kvstore import KVStoreClient 7 | 8 | 9 | class LFUCache(Cache): 10 | """ 11 | Least-frequently-used (LFU) Cache 12 | """ 13 | 14 | def __init__(self, edge_cache_ratio: int, node_cache_ratio: int, 15 | num_nodes: int, num_edges: int, 16 | device: Union[str, torch.device], 17 | node_feats: Optional[torch.Tensor] = None, 18 | edge_feats: Optional[torch.Tensor] = None, 19 | dim_node_feat: Optional[int] = 0, 20 | dim_edge_feat: Optional[int] = 0, 21 | pinned_nfeat_buffs: 
Optional[torch.Tensor] = None, 22 | pinned_efeat_buffs: Optional[torch.Tensor] = None, 23 | kvstore_client: Optional[KVStoreClient] = None, 24 | distributed: Optional[bool] = False, 25 | neg_sample_ratio: Optional[int] = 1): 26 | """ 27 | Initialize the cache 28 | 29 | Args: 30 | edge_cache_ratio: The edge ratio of the cache size to the total number of nodes or edges 31 | range: [0, 1]. 32 | node_cache_ratio: The node ratio of the cache size to the total number of nodes or edges 33 | range: [0, 1]. 34 | num_nodes: The number of nodes in the graph 35 | num_edges: The number of edges in the graph 36 | device: The device to use 37 | node_feats: The node features 38 | edge_feats: The edge features 39 | dim_node_feat: The dimension of node features 40 | dim_edge_feat: The dimension of edge features 41 | pinned_nfeat_buffs: The pinned memory buffers for node features 42 | pinned_efeat_buffs: The pinned memory buffers for edge features 43 | kvstore_client: The KVStore_Client for fetching features when using distributed 44 | training 45 | distributed: Whether to use distributed training 46 | neg_sample_ratio: The ratio of negative samples to positive samples 47 | """ 48 | super(LFUCache, self).__init__(edge_cache_ratio, node_cache_ratio, num_nodes, 49 | num_edges, device, node_feats, 50 | edge_feats, dim_node_feat, 51 | dim_edge_feat, pinned_nfeat_buffs, 52 | pinned_efeat_buffs, kvstore_client, 53 | distributed, neg_sample_ratio) 54 | self.name = 'lfu' 55 | 56 | if self.dim_node_feat != 0: 57 | self.cache_node_count = torch.zeros( 58 | self.node_capacity, dtype=torch.int32, device=self.device) 59 | if self.dim_edge_feat != 0: 60 | self.cache_edge_count = torch.zeros( 61 | self.edge_capacity, dtype=torch.int32, device=self.device) 62 | 63 | def get_mem_size(self) -> int: 64 | """ 65 | Get the memory size of the cache in bytes 66 | """ 67 | mem_size = super(LFUCache, self).get_mem_size() 68 | if self.dim_node_feat != 0: 69 | mem_size += self.cache_node_count.element_size() * self.cache_node_count.nelement() 70 | if self.dim_edge_feat != 0: 71 | mem_size += self.cache_edge_count.element_size() * self.cache_edge_count.nelement() 72 | return mem_size 73 | 74 | def init_cache(self, *args, **kwargs): 75 | """ 76 | Init the caching with features 77 | """ 78 | if self.distributed: 79 | return 80 | super(LFUCache, self).init_cache(*args, **kwargs) 81 | if self.dim_node_feat != 0: 82 | self.cache_node_count[self.cache_index_to_node_id] += 1 83 | 84 | if self.dim_edge_feat != 0: 85 | self.cache_edge_count[self.cache_index_to_edge_id] += 1 86 | 87 | def reset(self): 88 | """ 89 | Reset the cache 90 | """ 91 | # NB: only edge cache is reset 92 | if self.distributed: 93 | if self.dim_edge_feat != 0 and self.edge_capacity > 0: 94 | keys, feats = self.kvstore_client.init_cache( 95 | self.edge_capacity) 96 | cache_edge_id = torch.arange( 97 | len(keys), dtype=torch.int64, device=self.device) 98 | self.cache_edge_buffer[cache_edge_id] = feats.to( 99 | self.device).float() 100 | self.cache_edge_flag[cache_edge_id] = True 101 | self.cache_index_to_edge_id[cache_edge_id] = keys.to( 102 | self.device) 103 | self.cache_edge_map[keys] = cache_edge_id 104 | 105 | self.cache_edge_count.zero_() 106 | else: 107 | if self.edge_feats is not None: 108 | cache_edge_id = torch.arange( 109 | self.edge_capacity, dtype=torch.int64, device=self.device) 110 | 111 | # Init parameters related to feature fetching 112 | self.cache_edge_buffer[cache_edge_id] = self.edge_feats[:self.edge_capacity].to( 113 | self.device, non_blocking=True) 
114 | self.cache_edge_flag[cache_edge_id] = True 115 | self.cache_index_to_edge_id = cache_edge_id 116 | self.cache_edge_map[cache_edge_id] = cache_edge_id 117 | 118 | self.cache_edge_count.zero_() 119 | 120 | def resize(self, new_num_nodes: int, new_num_edges: int): 121 | """ 122 | Resize the cache 123 | 124 | Args: 125 | new_num_nodes: The new number of nodes 126 | new_num_edges: The new number of edges 127 | """ 128 | super(LFUCache, self).resize(new_num_nodes, new_num_edges) 129 | if self.dim_node_feat != 0: 130 | self.cache_node_count.resize_(self.node_capacity) 131 | if self.dim_edge_feat != 0: 132 | self.cache_edge_count.resize_(self.edge_capacity) 133 | 134 | def update_node_cache(self, cached_node_index: torch.Tensor, 135 | uncached_node_id: torch.Tensor, 136 | uncached_node_feature: torch.Tensor): 137 | """ 138 | Update the node cache 139 | 140 | Args: 141 | cached_node_index: The index of the cached nodes 142 | uncached_node_id: The id of the uncached nodes 143 | uncached_node_feature: The features of the uncached nodes 144 | """ 145 | # If the number of nodes to cache is larger than the cache capacity, 146 | # we only cache the first self.capacity nodes 147 | if len(uncached_node_id) > self.node_capacity: 148 | num_node_to_cache = self.node_capacity 149 | else: 150 | num_node_to_cache = len(uncached_node_id) 151 | node_id_to_cache = uncached_node_id[:num_node_to_cache] 152 | node_feature_to_cache = uncached_node_feature[:num_node_to_cache] 153 | 154 | # update cached node index first 155 | self.cache_node_count[cached_node_index] += 1 156 | 157 | # get the k node id with the least water level 158 | removing_cache_index = torch.topk( 159 | self.cache_node_count, k=num_node_to_cache, largest=False).indices 160 | assert len(removing_cache_index) == len( 161 | node_id_to_cache) == len(node_feature_to_cache) 162 | removing_node_id = self.cache_index_to_node_id[removing_cache_index] 163 | 164 | # update cache attributes 165 | self.cache_node_buffer[removing_cache_index] = node_feature_to_cache 166 | self.cache_node_count[removing_cache_index] = 1 167 | self.cache_node_flag[removing_node_id] = False 168 | self.cache_node_flag[node_id_to_cache] = True 169 | self.cache_node_map[removing_node_id] = -1 170 | self.cache_node_map[node_id_to_cache] = removing_cache_index 171 | self.cache_index_to_node_id[removing_cache_index] = node_id_to_cache 172 | 173 | def update_edge_cache(self, cached_edge_index: torch.Tensor, 174 | uncached_edge_id: torch.Tensor, 175 | uncached_edge_feature: torch.Tensor): 176 | """ 177 | Update the edge cache 178 | 179 | Args: 180 | cached_edge_index: The index of the cached edges 181 | uncached_edge_id: The id of the uncached edges 182 | uncached_edge_feature: The features of the uncached edges 183 | """ 184 | # If the number of edges to cache is larger than the cache capacity, 185 | # we only cache the first self.capacity edges 186 | if len(uncached_edge_id) > self.edge_capacity: 187 | num_edge_to_cache = self.edge_capacity 188 | else: 189 | num_edge_to_cache = len(uncached_edge_id) 190 | edge_id_to_cache = uncached_edge_id[:num_edge_to_cache] 191 | edge_feature_to_cache = uncached_edge_feature[:num_edge_to_cache] 192 | 193 | # update cached edge index first 194 | self.cache_edge_count[cached_edge_index] += 1 195 | 196 | # get the k edge id with the least water level 197 | removing_cache_index = torch.topk( 198 | self.cache_edge_count, k=num_edge_to_cache, largest=False).indices 199 | assert len(removing_cache_index) == len( 200 | edge_id_to_cache) == 
len(edge_feature_to_cache) 201 | removing_edge_id = self.cache_index_to_edge_id[removing_cache_index] 202 | 203 | # update cache attributes 204 | self.cache_edge_buffer[removing_cache_index] = edge_feature_to_cache 205 | self.cache_edge_count[removing_cache_index] = 1 206 | self.cache_edge_flag[removing_edge_id] = False 207 | self.cache_edge_flag[edge_id_to_cache] = True 208 | self.cache_edge_map[removing_edge_id] = -1 209 | self.cache_edge_map[edge_id_to_cache] = removing_cache_index 210 | self.cache_index_to_edge_id[removing_cache_index] = edge_id_to_cache 211 | -------------------------------------------------------------------------------- /gnnflow/cache/lru_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | 5 | from gnnflow.cache.cache import Cache 6 | from gnnflow.distributed.kvstore import KVStoreClient 7 | 8 | 9 | class LRUCache(Cache): 10 | """ 11 | Least-recently-used (LRU) cache 12 | """ 13 | 14 | def __init__(self, edge_cache_ratio: int, node_cache_ratio: int, 15 | num_nodes: int, num_edges: int, 16 | device: Union[str, torch.device], 17 | node_feats: Optional[torch.Tensor] = None, 18 | edge_feats: Optional[torch.Tensor] = None, 19 | dim_node_feat: Optional[int] = 0, 20 | dim_edge_feat: Optional[int] = 0, 21 | pinned_nfeat_buffs: Optional[torch.Tensor] = None, 22 | pinned_efeat_buffs: Optional[torch.Tensor] = None, 23 | kvstore_client: Optional[KVStoreClient] = None, 24 | distributed: Optional[bool] = False, 25 | neg_sample_ratio: Optional[int] = 1): 26 | """ 27 | Initialize the cache 28 | 29 | Args: 30 | edge_cache_ratio: The edge ratio of the cache size to the total number of nodes or edges 31 | range: [0, 1]. 32 | node_cache_ratio: The node ratio of the cache size to the total number of nodes or edges 33 | range: [0, 1]. 
34 | num_nodes: The number of nodes in the graph 35 | num_edges: The number of edges in the graph 36 | device: The device to use 37 | node_feats: The node features 38 | edge_feats: The edge features 39 | dim_node_feat: The dimension of node features 40 | dim_edge_feat: The dimension of edge features 41 | pinned_nfeat_buffs: The pinned memory buffers for node features 42 | pinned_efeat_buffs: The pinned memory buffers for edge features 43 | kvstore_client: The KVStore_Client for fetching features when using distributed 44 | training 45 | distributed: Whether to use distributed training 46 | neg_sample_ratio: The ratio of negative samples to positive samples 47 | """ 48 | super(LRUCache, self).__init__(edge_cache_ratio, node_cache_ratio, num_nodes, 49 | num_edges, device, node_feats, 50 | edge_feats, dim_node_feat, 51 | dim_edge_feat, pinned_nfeat_buffs, 52 | pinned_efeat_buffs, kvstore_client, 53 | distributed, neg_sample_ratio) 54 | self.name = 'lru' 55 | 56 | if self.dim_node_feat != 0: 57 | self.cache_node_count = torch.zeros( 58 | self.node_capacity, dtype=torch.int32, device=self.device) 59 | if self.dim_edge_feat != 0: 60 | self.cache_edge_count = torch.zeros( 61 | self.edge_capacity, dtype=torch.int32, device=self.device) 62 | 63 | def get_mem_size(self) -> int: 64 | """ 65 | Get the memory size of the cache in bytes 66 | """ 67 | mem_size = super(LRUCache, self).get_mem_size() 68 | if self.dim_node_feat != 0: 69 | mem_size += self.cache_node_count.element_size() * self.cache_node_count.nelement() 70 | if self.dim_edge_feat != 0: 71 | mem_size += self.cache_edge_count.element_size() * self.cache_edge_count.nelement() 72 | return mem_size 73 | 74 | def reset(self): 75 | """ 76 | Reset the cache 77 | """ 78 | if self.kvstore_client is not None: 79 | self.kvstore_client.comm_time = 0 80 | # NB: only edge cache is reset 81 | if self.distributed: 82 | if self.dim_edge_feat != 0 and self.edge_capacity > 0: 83 | keys, feats = self.kvstore_client.init_cache( 84 | self.edge_capacity) 85 | cache_edge_id = torch.arange( 86 | len(keys), dtype=torch.int64, device=self.device) 87 | self.cache_edge_buffer[cache_edge_id] = feats.to( 88 | self.device).float() 89 | self.cache_edge_flag[cache_edge_id] = True 90 | self.cache_index_to_edge_id[cache_edge_id] = keys.to( 91 | self.device) 92 | self.cache_edge_map[keys] = cache_edge_id 93 | 94 | self.cache_edge_count.zero_() 95 | else: 96 | if self.edge_feats is not None: 97 | cache_edge_id = torch.arange( 98 | self.edge_capacity, dtype=torch.int64, device=self.device) 99 | 100 | # Init parameters related to feature fetching 101 | self.cache_edge_buffer[cache_edge_id] = self.edge_feats[:self.edge_capacity].to( 102 | self.device, non_blocking=True) 103 | self.cache_edge_flag[cache_edge_id] = True 104 | self.cache_index_to_edge_id = cache_edge_id 105 | self.cache_edge_map[cache_edge_id] = cache_edge_id 106 | 107 | self.cache_edge_count.zero_() 108 | 109 | def resize(self, new_num_nodes: int, new_num_edges: int): 110 | """ 111 | Resize the cache 112 | 113 | Args: 114 | new_num_nodes: The new number of nodes 115 | new_num_edges: The new number of edges 116 | """ 117 | super(LRUCache, self).resize(new_num_nodes, new_num_edges) 118 | if self.dim_node_feat != 0: 119 | self.cache_node_count.resize_(self.node_capacity) 120 | if self.dim_edge_feat != 0: 121 | self.cache_edge_count.resize_(self.edge_capacity) 122 | 123 | def update_node_cache(self, cached_node_index: torch.Tensor, 124 | uncached_node_id: torch.Tensor, 125 | uncached_node_feature: torch.Tensor): 126 | 
""" 127 | Update the node cache 128 | 129 | Args: 130 | cached_node_index: The index of the cached nodes 131 | uncached_node_id: The id of the uncached nodes 132 | uncached_node_feature: The features of the uncached nodes 133 | """ 134 | # If the number of nodes to cache is larger than the cache capacity, 135 | # we only cache the first self.capacity nodes 136 | if len(uncached_node_id) > self.node_capacity: 137 | num_node_to_cache = self.node_capacity 138 | else: 139 | num_node_to_cache = len(uncached_node_id) 140 | node_id_to_cache = uncached_node_id[:num_node_to_cache] 141 | node_feature_to_cache = uncached_node_feature[:num_node_to_cache] 142 | 143 | # first all -1 144 | self.cache_node_count -= 1 145 | # update cached node index to 0 (0 is the highest priority) 146 | self.cache_node_count[cached_node_index] = 0 147 | 148 | # get the k node id with the least water level 149 | removing_cache_index = torch.topk( 150 | self.cache_node_count, k=num_node_to_cache, largest=False).indices 151 | assert len(removing_cache_index) == len( 152 | node_id_to_cache) == len(node_feature_to_cache) 153 | removing_node_id = self.cache_index_to_node_id[removing_cache_index] 154 | 155 | # update cache attributes 156 | self.cache_node_buffer[removing_cache_index] = node_feature_to_cache 157 | self.cache_node_count[removing_cache_index] = 0 158 | self.cache_node_flag[removing_node_id] = False 159 | self.cache_node_flag[node_id_to_cache] = True 160 | self.cache_node_map[removing_node_id] = -1 161 | self.cache_node_map[node_id_to_cache] = removing_cache_index 162 | self.cache_index_to_node_id[removing_cache_index] = node_id_to_cache 163 | 164 | def update_edge_cache(self, cached_edge_index: torch.Tensor, 165 | uncached_edge_id: torch.Tensor, 166 | uncached_edge_feature: torch.Tensor): 167 | """ 168 | Update the edge cache 169 | 170 | Args: 171 | cached_edge_index: The index of the cached edges 172 | uncached_edge_id: The id of the uncached edges 173 | uncached_edge_feature: The features of the uncached edges 174 | """ 175 | # If the number of edges to cache is larger than the cache capacity, 176 | # we only cache the first self.capacity edges 177 | if len(uncached_edge_id) > self.edge_capacity: 178 | num_edge_to_cache = self.edge_capacity 179 | else: 180 | num_edge_to_cache = len(uncached_edge_id) 181 | edge_id_to_cache = uncached_edge_id[:num_edge_to_cache] 182 | edge_feature_to_cache = uncached_edge_feature[:num_edge_to_cache] 183 | 184 | # first all -1 185 | self.cache_edge_count -= 1 186 | # update cached edge index to 0 (0 is the highest priority) 187 | self.cache_edge_count[cached_edge_index] = 0 188 | 189 | # get the k edge id with the least water level 190 | removing_cache_index = torch.topk( 191 | self.cache_edge_count, k=num_edge_to_cache, largest=False).indices 192 | assert len(removing_cache_index) == len( 193 | edge_id_to_cache) == len(edge_feature_to_cache) 194 | removing_edge_id = self.cache_index_to_edge_id[removing_cache_index] 195 | 196 | # update cache attributes 197 | self.cache_edge_buffer[removing_cache_index] = edge_feature_to_cache 198 | self.cache_edge_count[removing_cache_index] = 0 199 | self.cache_edge_flag[removing_edge_id] = False 200 | self.cache_edge_flag[edge_id_to_cache] = True 201 | self.cache_edge_map[removing_edge_id] = -1 202 | self.cache_edge_map[edge_id_to_cache] = removing_cache_index 203 | self.cache_index_to_edge_id[removing_cache_index] = edge_id_to_cache 204 | -------------------------------------------------------------------------------- /gnnflow/config.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | 3 | MiB = 1 << 20 4 | GiB = 1 << 30 5 | 6 | 7 | def get_default_config(model: str, dataset: str): 8 | """ 9 | Get default configuration for a model and dataset. 10 | 11 | Args: 12 | model: Model name. 13 | dataset: Name of the dataset. 14 | 15 | Returns: 16 | Default configuration for the model and dataset. 17 | """ 18 | model, dataset = model.lower(), dataset.lower() 19 | assert model in ["tgn", "tgat", "dysat", "graphsage", "gat", "jodie", 'apan'] and dataset in [ 20 | "wiki", "reddit", "mooc", "lastfm", "gdelt", "mag"], "Invalid model or dataset." 21 | 22 | mod = sys.modules[__name__] 23 | return getattr( 24 | mod, f"_{model}_default_config"), getattr( 25 | mod, f"_{dataset}_default_config") 26 | 27 | 28 | _tgn_default_config = { 29 | "dropout": 0.1, 30 | "att_head": 2, 31 | "att_dropout": 0.1, 32 | "num_layers": 1, 33 | "fanouts": [10], 34 | "sample_strategy": "recent", 35 | "num_snapshots": 1, 36 | "snapshot_time_window": 0, 37 | "prop_time": False, 38 | "use_memory": True, 39 | "dim_time": 100, 40 | "dim_embed": 100, 41 | "dim_memory": 100, 42 | "batch_size": 600 43 | } 44 | 45 | _apan_default_config = { 46 | "dropout": 0.1, 47 | "att_head": 2, 48 | "att_dropout": 0.1, 49 | "num_layers": 1, 50 | "fanouts": [10], 51 | "sample_strategy": "recent", 52 | "num_snapshots": 1, 53 | "snapshot_time_window": 0, 54 | "prop_time": False, 55 | "use_memory": True, 56 | "dim_time": 100, 57 | "dim_embed": 100, 58 | "dim_memory": 100, 59 | "batch_size": 600 60 | } 61 | 62 | _jodie_default_config = { 63 | "dropout": 0.1, 64 | "att_head": 2, 65 | "att_dropout": 0.1, 66 | "num_layers": 1, 67 | "fanouts": [10], 68 | "sample_strategy": "recent", 69 | "num_snapshots": 1, 70 | "snapshot_time_window": 0, 71 | "prop_time": False, 72 | "use_memory": True, 73 | "dim_time": 100, 74 | "dim_embed": 100, 75 | "dim_memory": 100, 76 | "batch_size": 600 77 | } 78 | 79 | _tgat_default_config = { 80 | "dropout": 0.1, 81 | "att_head": 2, 82 | "att_dropout": 0.1, 83 | "num_layers": 2, 84 | "fanouts": [10, 10], 85 | "sample_strategy": "uniform", 86 | "num_snapshots": 1, 87 | "snapshot_time_window": 0, 88 | "prop_time": False, 89 | "use_memory": False, 90 | "dim_time": 100, 91 | "dim_embed": 100, 92 | "batch_size": 600 93 | } 94 | 95 | _dysat_default_config = { 96 | "dropout": 0.1, 97 | "att_head": 2, 98 | "att_head": 2, 99 | "att_dropout": 0.1, 100 | "num_layers": 2, 101 | "fanouts": [10, 10], 102 | "sample_strategy": "uniform", 103 | "num_snapshots": 3, 104 | "snapshot_time_window": 10000, 105 | "prop_time": True, 106 | "use_memory": False, 107 | "dim_time": 0, 108 | "dim_embed": 100, 109 | "batch_size": 4000 110 | } 111 | 112 | _graphsage_default_config = { 113 | "dim_embed": 100, 114 | "num_layers": 2, 115 | "aggregator": 'mean', 116 | "fanouts": [15, 10], 117 | "sample_strategy": "uniform", 118 | "num_snapshots": 1, 119 | "snapshot_time_window": 0, 120 | "prop_time": False, 121 | "use_memory": False, 122 | "is_static": True, 123 | "batch_size": 1200 124 | } 125 | 126 | _gat_default_config = { 127 | "dropout": 0.1, 128 | "att_head": 2, 129 | "att_dropout": 0.1, 130 | "num_layers": 2, 131 | "fanouts": [10, 10], 132 | "sample_strategy": "uniform", 133 | "num_snapshots": 1, 134 | "snapshot_time_window": 0, 135 | "prop_time": False, 136 | "use_memory": False, 137 | "dim_time": 0, 138 | "dim_embed": 100, 139 | "is_static": True, 140 | "batch_size": 600 141 | } 142 | 143 | _wiki_default_config = { 144 | "initial_pool_size": 10 
* MiB, 145 | "maximum_pool_size": 30 * MiB, 146 | "mem_resource_type": "cuda", 147 | "minimum_block_size": 18, 148 | "blocks_to_preallocate": 1024, 149 | "insertion_policy": "insert", 150 | "undirected": True, 151 | "node_feature": False, 152 | "edge_feature": True, 153 | } 154 | 155 | _reddit_default_config = { 156 | "initial_pool_size": 20 * MiB, 157 | "maximum_pool_size": 1000 * MiB, 158 | "mem_resource_type": "cuda", 159 | "minimum_block_size": 62, 160 | "blocks_to_preallocate": 1024, 161 | "insertion_policy": "insert", 162 | "undirected": False, 163 | "node_feature": True, 164 | "edge_feature": True, 165 | } 166 | 167 | _mooc_default_config = { 168 | "initial_pool_size": 20 * MiB, 169 | "maximum_pool_size": 50 * MiB, 170 | "mem_resource_type": "cuda", 171 | "minimum_block_size": 59, 172 | "blocks_to_preallocate": 1024, 173 | "insertion_policy": "insert", 174 | "undirected": False, 175 | "node_feature": False, 176 | "edge_feature": True, 177 | } 178 | 179 | _lastfm_default_config = { 180 | "initial_pool_size": 50 * MiB, 181 | "maximum_pool_size": 100 * MiB, 182 | "mem_resource_type": "cuda", 183 | "minimum_block_size": 650, 184 | "blocks_to_preallocate": 1024, 185 | "insertion_policy": "insert", 186 | "undirected": True, 187 | "node_feature": False, 188 | "edge_feature": True, 189 | } 190 | 191 | _gdelt_default_config = { 192 | "initial_pool_size": 10*GiB, 193 | "maximum_pool_size": 20*GiB, 194 | "mem_resource_type": "unified", 195 | "minimum_block_size": 123, 196 | "blocks_to_preallocate": 8196, 197 | "insertion_policy": "insert", 198 | "undirected": False, 199 | "node_feature": True, 200 | "edge_feature": True, 201 | } 202 | 203 | _mag_default_config = { 204 | "initial_pool_size": 5*GiB, 205 | "maximum_pool_size": 50*GiB, 206 | "mem_resource_type": "unified", 207 | "minimum_block_size": 11, 208 | "blocks_to_preallocate": 65536, 209 | "insertion_policy": "insert", 210 | "undirected": False, 211 | "node_feature": True, 212 | "edge_feature": False, 213 | } 214 | -------------------------------------------------------------------------------- /gnnflow/csrc/api.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | #include "common.h" 8 | #include "dynamic_graph.h" 9 | #include "kvstore.h" 10 | #include "temporal_sampler.h" 11 | 12 | namespace py = pybind11; 13 | 14 | using namespace gnnflow; 15 | 16 | template 17 | inline py::array vec2npy(const std::vector &vec) { 18 | // need to let python garbage collector handle C++ vector memory 19 | // see https://github.com/pybind/pybind11/issues/1042 20 | auto v = new std::vector(vec); 21 | auto capsule = py::capsule( 22 | v, [](void *v) { delete reinterpret_cast *>(v); }); 23 | return py::array(v->size(), v->data(), capsule); 24 | } 25 | 26 | PYBIND11_MODULE(libgnnflow, m) { 27 | py::enum_(m, "InsertionPolicy") 28 | .value("INSERT", InsertionPolicy::kInsertionPolicyInsert) 29 | .value("REPLACE", InsertionPolicy::kInsertionPolicyReplace); 30 | 31 | py::enum_(m, "SamplingPolicy") 32 | .value("RECENT", SamplingPolicy::kSamplingPolicyRecent) 33 | .value("UNIFORM", SamplingPolicy::kSamplingPolicyUniform); 34 | 35 | py::enum_(m, "MemoryResourceType") 36 | .value("CUDA", MemoryResourceType::kMemoryResourceTypeCUDA) 37 | .value("UNIFIED", MemoryResourceType::kMemoryResourceTypeUnified) 38 | .value("PINNED", MemoryResourceType::kMemoryResourceTypePinned) 39 | .value("SHARED", MemoryResourceType::kMemoryResourceTypeShared); 40 | 41 | py::class_(m, 
"_DynamicGraph") 42 | .def(py::init(), 44 | py::arg("initial_pool_size"), py::arg("maximum_pool_size"), 45 | py::arg("mem_resource_type"), py::arg("minium_block_size"), 46 | py::arg("blocks_to_preallocate"), py::arg("insertion_policy"), 47 | py::arg("device"), py::arg("adaptive_block_size")) 48 | .def("add_edges", &DynamicGraph::AddEdges, py::arg("source_vertices"), 49 | py::arg("target_vertices"), py::arg("timestamps"), py::arg("eids")) 50 | .def("offload_old_blocks", &DynamicGraph::OffloadOldBlocks, 51 | py::arg("timestamp"), py::arg("to_file") = false) 52 | .def("num_vertices", &DynamicGraph::num_nodes) 53 | .def("num_source_vertices", &DynamicGraph::num_src_nodes) 54 | .def("num_edges", &DynamicGraph::num_edges) 55 | .def("out_degree", 56 | [](const DynamicGraph &dgraph, std::vector nodes) { 57 | return vec2npy(dgraph.out_degree(nodes)); 58 | }) 59 | .def("nodes", 60 | [](const DynamicGraph &dgraph) { return vec2npy(dgraph.nodes()); }) 61 | .def("src_nodes", 62 | [](const DynamicGraph &dgraph) { 63 | return vec2npy(dgraph.src_nodes()); 64 | }) 65 | .def("edges", 66 | [](const DynamicGraph &dgraph) { return vec2npy(dgraph.edges()); }) 67 | .def("max_vertex_id", 68 | [](const DynamicGraph &dgraph) { return dgraph.max_node_id(); }) 69 | .def("get_temporal_neighbors", 70 | [](const DynamicGraph &dgraph, NIDType node) { 71 | auto neighbors = dgraph.get_temporal_neighbors(node); 72 | return py::make_tuple(vec2npy(std::get<0>(neighbors)), 73 | vec2npy(std::get<1>(neighbors)), 74 | vec2npy(std::get<2>(neighbors))); 75 | }) 76 | .def("avg_linked_list_length", 77 | [](const DynamicGraph &dgraph) { 78 | return dgraph.avg_linked_list_length(); 79 | }) 80 | .def("get_graph_memory_usage", 81 | [](const DynamicGraph &dgraph) { return dgraph.graph_mem_usage(); }) 82 | .def("get_metadata_memory_usage", [](DynamicGraph &dgraph) { 83 | return dgraph.graph_metadata_mem_usage(); 84 | }); 85 | 86 | py::class_(m, "SamplingResult") 87 | .def("row", 88 | [](const SamplingResult &result) { return vec2npy(result.row); }) 89 | .def("col", 90 | [](const SamplingResult &result) { return vec2npy(result.col); }) 91 | .def("all_nodes", 92 | [](const SamplingResult &result) { 93 | return vec2npy(result.all_nodes); 94 | }) 95 | .def("all_timestamps", 96 | [](const SamplingResult &result) { 97 | return vec2npy(result.all_timestamps); 98 | }) 99 | .def("delta_timestamps", 100 | [](const SamplingResult &result) { 101 | return vec2npy(result.delta_timestamps); 102 | }) 103 | .def("eids", 104 | [](const SamplingResult &result) { return vec2npy(result.eids); }) 105 | .def("num_src_nodes", 106 | [](const SamplingResult &result) { return result.num_src_nodes; }) 107 | .def("num_dst_nodes", 108 | [](const SamplingResult &result) { return result.num_dst_nodes; }); 109 | 110 | py::class_(m, "_TemporalSampler") 111 | .def(py::init &, 112 | SamplingPolicy, uint32_t, float, bool, uint64_t>(), 113 | py::arg("dgraph"), py::arg("fanouts"), py::arg("sampling_policy"), 114 | py::arg("num_snapshots"), py::arg("snapshot_time_window"), 115 | py::arg("prop_time"), py::arg("seed")) 116 | .def("sample", &TemporalSampler::Sample) 117 | .def("sample_layer", &TemporalSampler::SampleLayer); 118 | 119 | py::class_(m, "KVStore") 120 | .def(py::init<>()) 121 | .def("set", &KVStore::set) 122 | .def("get", &KVStore::get) 123 | .def("memory_usage", &KVStore::memory_usage) 124 | .def("fill_zeros", &KVStore::fill_zeros); 125 | } 126 | -------------------------------------------------------------------------------- /gnnflow/csrc/common.h: 
-------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_COMMON_H_ 2 | #define GNNFLOW_COMMON_H_ 3 | 4 | #include 5 | #include 6 | 7 | namespace gnnflow { 8 | 9 | // NIDType is the type of node ID. 10 | // TimestampType is the type of timestamp. 11 | // EIDType is the type of edge ID. 12 | // NB: PyTorch does not support converting uint64_t's numpy ndarray to int64_t. 13 | using NIDType = int64_t; 14 | using TimestampType = float; 15 | using EIDType = int64_t; 16 | 17 | constexpr int kMaxFanout = 32; 18 | 19 | constexpr NIDType kInvalidNID = -1; 20 | 21 | constexpr int kNumStreams = 1; 22 | 23 | static constexpr std::size_t kBlockSpaceSize = 24 | (sizeof(NIDType) + sizeof(EIDType) + sizeof(TimestampType)); 25 | 26 | /** 27 | * @brief This POD is used to store the temporal blocks in the graph. 28 | * 29 | * The blocks are stored in a doubly linked list. The first block is the newest 30 | * block. Each block stores the neighbor nodes, timestamps of the edges and IDs 31 | * of edges. The neighbor nodes and corresponding edge ids are sorted by 32 | * timestamps. The block has a maximum capacity and can only store a certain 33 | * number of edges. The block can be moved to a different device. 34 | */ 35 | struct TemporalBlock { 36 | NIDType* dst_nodes; 37 | TimestampType* timestamps; 38 | EIDType* eids; 39 | 40 | std::size_t size; 41 | std::size_t capacity; 42 | 43 | TimestampType start_timestamp; 44 | TimestampType end_timestamp; 45 | 46 | TemporalBlock* prev; 47 | TemporalBlock* next; 48 | }; 49 | 50 | /** @brief This struct is used to store the sampling result. */ 51 | struct SamplingResult { 52 | std::vector row; 53 | std::vector col; 54 | std::vector all_nodes; 55 | std::vector all_timestamps; 56 | std::vector delta_timestamps; 57 | std::vector eids; 58 | std::size_t num_src_nodes; 59 | std::size_t num_dst_nodes; 60 | }; 61 | 62 | struct SamplingRange { 63 | int start_idx; 64 | int end_idx; 65 | }; 66 | 67 | /** 68 | * @brief InsertionPolicy is used to decide how to insert a new temporal block 69 | * into the linked list. 70 | * 71 | * kInsertionPolicyInsert: insert the new block at the head of the list. 72 | * kInsertionPolicyReplace: replace the head block with a larger block. 73 | */ 74 | enum class InsertionPolicy { kInsertionPolicyInsert, kInsertionPolicyReplace }; 75 | 76 | /** 77 | * @brief SamplePolicy is used to decide how to sample the dynamic graph. 78 | * 79 | * kSamplePolicyRecent: sample the most recent edges. 80 | * kSamplePolicyUniform: sample past edges uniformly. 
81 | */ 82 | enum class SamplingPolicy { kSamplingPolicyRecent, kSamplingPolicyUniform }; 83 | 84 | enum class MemoryResourceType { 85 | kMemoryResourceTypeCUDA, 86 | kMemoryResourceTypeUnified, 87 | kMemoryResourceTypePinned, 88 | kMemoryResourceTypeShared 89 | }; 90 | 91 | }; // namespace gnnflow 92 | 93 | #endif // GNNFLOW_COMMON_H_ 94 | -------------------------------------------------------------------------------- /gnnflow/csrc/doubly_linked_list.cu: -------------------------------------------------------------------------------- 1 | #include "doubly_linked_list.h" 2 | 3 | namespace gnnflow { 4 | 5 | __device__ void InsertBlockToDoublyLinkedList(DoublyLinkedList* node_table, 6 | NIDType node_id, 7 | TemporalBlock* block) { 8 | auto& list = node_table[node_id]; 9 | if (list.tail == nullptr) { 10 | list.tail = block; 11 | block->prev = nullptr; 12 | block->next = nullptr; 13 | } else { 14 | // append to the tail 15 | list.tail->next = block; 16 | block->prev = list.tail; 17 | block->next = nullptr; 18 | list.tail = block; 19 | } 20 | } 21 | 22 | __device__ void RemoveBlockFromDoublyLinkedList(DoublyLinkedList* node_table, 23 | NIDType node_id, 24 | TemporalBlock* block) { 25 | auto& list = node_table[node_id]; 26 | if (block->prev == nullptr && block->next == nullptr) { 27 | // only one block 28 | list.tail = nullptr; 29 | } else if (block->prev == nullptr) { 30 | // block is the head 31 | block->next->prev = nullptr; 32 | } else if (block->next == nullptr) { 33 | // block is the tail 34 | list.tail = block->prev; 35 | block->prev->next = nullptr; 36 | } else { 37 | // block is in the middle 38 | block->prev->next = block->next; 39 | block->next->prev = block->prev; 40 | } 41 | } 42 | 43 | void InsertBlockToDoublyLinkedList(HostDoublyLinkedList* node_table, 44 | NIDType node_id, TemporalBlock* block) { 45 | auto& list = node_table[node_id]; 46 | if (list.tail == nullptr) { 47 | list.tail = block; 48 | list.head = block; 49 | block->prev = nullptr; 50 | block->next = nullptr; 51 | } else { 52 | // append to the tail 53 | list.tail->next = block; 54 | block->prev = list.tail; 55 | block->next = nullptr; 56 | list.tail = block; 57 | } 58 | list.size++; 59 | } 60 | 61 | void RemoveBlockFromDoublyLinkedList(HostDoublyLinkedList* node_table, 62 | NIDType node_id, TemporalBlock* block) { 63 | auto& list = node_table[node_id]; 64 | if (block->prev == nullptr && block->next == nullptr) { 65 | // only one block 66 | list.head = list.tail = nullptr; 67 | } else if (block->prev == nullptr) { 68 | // block is the head 69 | list.head = block->next; 70 | block->next->prev = nullptr; 71 | } else if (block->next == nullptr) { 72 | // block is the tail 73 | list.tail = block->prev; 74 | block->prev->next = nullptr; 75 | } else { 76 | // block is in the middle 77 | block->prev->next = block->next; 78 | block->next->prev = block->prev; 79 | } 80 | list.size--; 81 | } 82 | 83 | __global__ void InsertBlockToDoublyLinkedListKernel( 84 | DoublyLinkedList* node_table, NIDType node_id, TemporalBlock* block) { 85 | InsertBlockToDoublyLinkedList(node_table, node_id, block); 86 | } 87 | 88 | __global__ void RemoveBlockFromDoublyLinkedListKernel( 89 | DoublyLinkedList* node_table, NIDType node_id, TemporalBlock* next_block) { 90 | RemoveBlockFromDoublyLinkedList(node_table, node_id, next_block); 91 | } 92 | } // namespace gnnflow 93 | -------------------------------------------------------------------------------- /gnnflow/csrc/doubly_linked_list.h: 
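Before the matching header, the list discipline implemented in doubly_linked_list.cu above is easy to state: every vertex owns a doubly linked list of temporal blocks, insertion always appends at the tail, and removal handles four cases (only block, head, tail, middle). A plain-Python mirror of the host-side variants, for intuition only — the real lists live in the device/host node tables and are maintained by the functions and kernels above:

```python
class Block:
    """Stand-in for TemporalBlock: only the prev/next links matter here."""
    def __init__(self):
        self.prev = self.next = None


class BlockList:
    """Mirror of HostDoublyLinkedList: head, tail and a size counter."""
    def __init__(self):
        self.head = self.tail = None
        self.size = 0

    def insert(self, block):
        if self.tail is None:                 # empty list
            self.head = self.tail = block
            block.prev = block.next = None
        else:                                 # append to the tail
            self.tail.next = block
            block.prev, block.next = self.tail, None
            self.tail = block
        self.size += 1

    def remove(self, block):
        if block.prev is None and block.next is None:   # only block
            self.head = self.tail = None
        elif block.prev is None:                        # head
            self.head = block.next
            block.next.prev = None
        elif block.next is None:                        # tail
            self.tail = block.prev
            block.prev.next = None
        else:                                           # middle
            block.prev.next = block.next
            block.next.prev = block.prev
        self.size -= 1
```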
-------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_DOUBLY_LINKED_LIST_H_ 2 | #define GNNFLOW_DOUBLY_LINKED_LIST_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include "common.h" 8 | #include "logging.h" 9 | 10 | namespace gnnflow { 11 | 12 | /** 13 | * @brief This class is doubly linked list of temporal blocks. 14 | */ 15 | struct DoublyLinkedList { 16 | TemporalBlock* tail; 17 | 18 | __device__ DoublyLinkedList() : tail(nullptr) {} 19 | }; 20 | 21 | struct HostDoublyLinkedList { 22 | TemporalBlock* head; 23 | TemporalBlock* tail; 24 | std::size_t num_edges; 25 | std::size_t num_insertions; 26 | std::size_t size; 27 | 28 | HostDoublyLinkedList() 29 | : head(nullptr), 30 | tail(nullptr), 31 | num_edges(0), 32 | num_insertions(0), 33 | size(0) {} 34 | }; 35 | 36 | __device__ void InsertBlockToDoublyLinkedList(DoublyLinkedList* node_table, 37 | NIDType node_id, 38 | TemporalBlock* block); 39 | 40 | __device__ void RemoveBlockFromDoublyLinkedList(DoublyLinkedList* node_table, 41 | NIDType node_id, 42 | TemporalBlock* block); 43 | 44 | void InsertBlockToDoublyLinkedList(HostDoublyLinkedList* node_table, 45 | NIDType node_id, TemporalBlock* block); 46 | 47 | void RemoveBlockFromDoublyLinkedList(HostDoublyLinkedList* node_table, 48 | NIDType node_id, TemporalBlock* block); 49 | 50 | __global__ void InsertBlockToDoublyLinkedListKernel( 51 | DoublyLinkedList* node_table, NIDType node_id, TemporalBlock* block); 52 | 53 | __global__ void RemoveBlockFromDoublyLinkedListKernel( 54 | DoublyLinkedList* node_table, NIDType node_id, TemporalBlock* next_block); 55 | 56 | } // namespace gnnflow 57 | #endif // GNNFLOW_DOUBLY_LINKED_LIST_H_ 58 | -------------------------------------------------------------------------------- /gnnflow/csrc/dynamic_graph.h: -------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_DYNAMIC_GRAPH_H_ 2 | #define GNNFLOW_DYNAMIC_GRAPH_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "common.h" 15 | #include "doubly_linked_list.h" 16 | #include "temporal_block_allocator.h" 17 | 18 | namespace gnnflow { 19 | typedef thrust::device_vector DeviceNodeTable; 20 | typedef std::vector HostNodeTable; 21 | /** 22 | * @brief A dynamic graph is a graph that can be modified at runtime. 23 | * 24 | * The dynamic graph is implemented as block adjacency list. It has a node 25 | * table where each entry is a linked list of temporal blocks. 26 | */ 27 | class DynamicGraph { 28 | public: 29 | /** 30 | * @brief Constructor. 31 | * 32 | * It initialize a temporal block allocator with a memory pool for storing 33 | * edges. The type of the memory resource is determined by the 34 | * `mem_resource_type` parameter. It also creates a device memory pool for 35 | * metadata (i.e., blocks). 36 | * 37 | * @param initial_pool_size The initial size of the memory pool. 38 | * @param maxmium_pool_size The maximum size of the memory pool. 39 | * @param mem_resource_type The type of memory resource for the memory pool. 40 | * @param minimum_block_size The minimum size of the temporal block. 41 | * @param blocks_to_preallocate The number of blocks to preallocate. 42 | * @param insertion_policy The insertion policy for the linked list. 43 | * @param device The device id. 44 | * @param adaptive_block_size Whether to use adaptive block size. 
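 *
 * A minimal construction sketch (the pool sizes, minimum block size and
 * preallocation count below are illustrative values, not recommended
 * defaults):
 *
 *   DynamicGraph graph(1UL << 30, 2UL << 30,
 *                      MemoryResourceType::kMemoryResourceTypeCUDA,
 *                      64, 1024,
 *                      InsertionPolicy::kInsertionPolicyInsert,
 *                      0, true);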
45 | */ 46 | 47 | DynamicGraph(std::size_t initial_pool_size, std::size_t maximum_pool_size, 48 | MemoryResourceType mem_resource_type, 49 | std::size_t minium_block_size, std::size_t blocks_to_preallocate, 50 | InsertionPolicy insertion_policy, int device, 51 | bool adaptive_block_size); 52 | ~DynamicGraph(); 53 | 54 | /** 55 | * @brief Add edges to the graph. 56 | * 57 | * Note that we do not assume that the incoming edges are sorted by 58 | * timestamps. The function will sort them. 59 | * 60 | * @params src_nodes The source nodes of the edges. 61 | * @params dst_nodes The destination nodes of the edges. 62 | * @params timestamps The timestamps of the edges. 63 | * @params eids The edge ids of the edges. 64 | * 65 | */ 66 | void AddEdges(const std::vector& src_nodes, 67 | const std::vector& dst_nodes, 68 | const std::vector& timestamps, 69 | const std::vector& eids); 70 | 71 | /** 72 | * @brief Add nodes to the graph. 73 | * 74 | * @params max_node The maximum node id. 75 | */ 76 | void AddNodes(NIDType max_node); 77 | 78 | /** 79 | * @brief Offload all blocks that are older than the given timestamp. 80 | * 81 | * @params src_node The source node of the blocks. 82 | * @params timestamp The timestamp of the blocks. 83 | * 84 | * @return The number of blocks offloaded. 85 | */ 86 | std::size_t OffloadOldBlocks(TimestampType timestamp, bool to_file = false); 87 | 88 | std::size_t num_nodes() const; 89 | std::size_t num_edges() const; 90 | std::size_t num_src_nodes() const; 91 | 92 | std::vector nodes() const; 93 | std::vector src_nodes() const; 94 | std::vector edges() const; 95 | 96 | NIDType max_node_id() const; 97 | 98 | std::vector out_degree(const std::vector& nodes) const; 99 | 100 | // NB: it is inefficient to call this function every time for each node. Debug 101 | // only. 102 | typedef std::tuple, std::vector, 103 | std::vector> 104 | NodeNeighborTuple; 105 | NodeNeighborTuple get_temporal_neighbors(NIDType node) const; 106 | 107 | const DoublyLinkedList* get_device_node_table() const; 108 | 109 | int device() const { return device_; } 110 | 111 | float avg_linked_list_length() const; 112 | 113 | // NB: does not include metadata. only the edge data. 114 | float graph_mem_usage() const; 115 | 116 | float graph_metadata_mem_usage(); 117 | 118 | private: 119 | void AddEdgesForOneNode(NIDType src_node, 120 | const std::vector& dst_nodes, 121 | const std::vector& timestamps, 122 | const std::vector& eids, 123 | cudaStream_t stream = nullptr); 124 | 125 | void InsertBlock(NIDType node_id, TemporalBlock* block, 126 | cudaStream_t stream = nullptr); 127 | 128 | void RemoveBlock(NIDType node_id, TemporalBlock* block, 129 | cudaStream_t stream = nullptr); 130 | 131 | void SyncBlock(TemporalBlock* block, cudaStream_t stream = nullptr); 132 | 133 | private: 134 | TemporalBlockAllocator allocator_; 135 | 136 | // The device node table. Blocks are allocated in the device memory pool. 137 | DeviceNodeTable d_node_table_; 138 | 139 | // The copy of the device node table in the host. 
140 | HostNodeTable h_copy_of_d_node_table_; 141 | 142 | // the host pointer to the block -> the device pointer to the block 143 | std::unordered_map h2d_mapping_; 144 | 145 | InsertionPolicy insertion_policy_; 146 | 147 | std::vector streams_; 148 | 149 | std::size_t max_node_id_; 150 | 151 | std::set nodes_; 152 | std::set src_nodes_; 153 | std::unordered_map edges_; 154 | 155 | std::stack mem_resources_for_metadata_; 156 | 157 | const int device_; 158 | bool adaptive_block_size_; 159 | }; 160 | 161 | } // namespace gnnflow 162 | 163 | #endif // GNNFLOW_DYNAMIC_GRAPH_H_ 164 | -------------------------------------------------------------------------------- /gnnflow/csrc/kvstore.cc: -------------------------------------------------------------------------------- 1 | #include "kvstore.h" 2 | 3 | #include "utils.h" 4 | 5 | namespace gnnflow { 6 | void KVStore::set(const std::vector& keys, const at::Tensor& values) { 7 | std::lock_guard lock(mutex_); 8 | auto num_keys = keys.size(); 9 | for (std::size_t i = 0; i < num_keys; ++i) { 10 | store_[keys[i]] = values[i]; 11 | } 12 | } 13 | 14 | std::vector KVStore::get(const std::vector& keys) { 15 | auto num_keys = keys.size(); 16 | std::vector values(num_keys); 17 | for (size_t i = 0; i < num_keys; ++i) { 18 | values[i] = store_[keys[i]]; 19 | } 20 | return values; 21 | } 22 | 23 | void KVStore::fill_zeros() { 24 | for (auto it = store_.begin(); it != store_.end(); ++it) { 25 | it->second.fill_(0); 26 | } 27 | } 28 | } // namespace gnnflow 29 | -------------------------------------------------------------------------------- /gnnflow/csrc/kvstore.h: -------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_KVSTORE_H 2 | #define GNNFLOW_KVSTORE_H 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "absl/container/flat_hash_map.h" 10 | 11 | namespace gnnflow { 12 | 13 | class KVStore { 14 | public: 15 | using Key = unsigned int; 16 | KVStore() = default; 17 | ~KVStore() = default; 18 | 19 | void set(const std::vector& keys, const at::Tensor& values); 20 | 21 | std::vector get(const std::vector& keys); 22 | 23 | void fill_zeros(); 24 | 25 | std::size_t memory_usage() const { 26 | // only count the memory usage of the map 27 | std::size_t total = (sizeof(Key) + sizeof(at::Tensor)) * store_.size(); 28 | return total; 29 | } 30 | 31 | private: 32 | absl::flat_hash_map store_; 33 | std::mutex mutex_; 34 | }; 35 | 36 | } // namespace gnnflow 37 | #endif // GNNFLOW_KVSTORE_H 38 | -------------------------------------------------------------------------------- /gnnflow/csrc/logging.cc: -------------------------------------------------------------------------------- 1 | #include "logging.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace gnnflow { 8 | LogLevel ParseLogLevelStr(const char* env_var_val) { 9 | std::string min_log_level(env_var_val); 10 | std::transform(min_log_level.begin(), min_log_level.end(), 11 | min_log_level.begin(), ::tolower); 12 | if (min_log_level == "trace") { 13 | return LogLevel::TRACE; 14 | } else if (min_log_level == "debug") { 15 | return LogLevel::DEBUG; 16 | } else if (min_log_level == "info") { 17 | return LogLevel::INFO; 18 | } else if (min_log_level == "warning") { 19 | return LogLevel::WARNING; 20 | } else if (min_log_level == "error") { 21 | return LogLevel::ERROR; 22 | } else if (min_log_level == "fatal") { 23 | return LogLevel::FATAL; 24 | } else { 25 | return LogLevel::WARNING; 26 | } 27 | } 28 | 29 | LogLevel MinLogLevelFromEnv() { 30 | const char* 
env_var_val = getenv("LOGLEVEL"); 31 | if (env_var_val == nullptr) { 32 | return LogLevel::WARNING; 33 | } 34 | return ParseLogLevelStr(env_var_val); 35 | } 36 | 37 | LogMessage::LogMessage(const char* file, int line, LogLevel level) 38 | : file_(file), line_(line), level_(level) {} 39 | 40 | LogMessage::~LogMessage() { 41 | bool use_cout = static_cast(level_) <= static_cast(LogLevel::INFO); 42 | std::ostream& os = use_cout ? std::cout : std::cerr; 43 | 44 | if (level_ >= MinLogLevelFromEnv()) { 45 | os << "[" << LOG_LEVELS[static_cast(level_)] << "] " << file_ << ":" 46 | << line_ << ": " << str() << std::endl; 47 | } 48 | } 49 | 50 | LogMessageFatal::LogMessageFatal(const char* file, int line) 51 | : LogMessage(file, line, LogLevel::FATAL) {} 52 | 53 | LogMessageFatal::~LogMessageFatal() { std::abort(); } 54 | }; // namespace gnnflow 55 | -------------------------------------------------------------------------------- /gnnflow/csrc/logging.h: -------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_LOGGING_H_ 2 | #define GNNFLOW_LOGGING_H_ 3 | 4 | #include 5 | #include 6 | 7 | namespace gnnflow { 8 | 9 | #define CUDA_CALL(func) \ 10 | { \ 11 | cudaError_t e = (func); \ 12 | if (e != cudaSuccess && e != cudaErrorCudartUnloading) \ 13 | LOG(FATAL) << "CUDA error " << cudaGetErrorString(e) << " at " \ 14 | << __FILE__ << ":" << __LINE__; \ 15 | } 16 | 17 | enum class LogLevel { TRACE, DEBUG, INFO, WARNING, ERROR, FATAL }; 18 | 19 | constexpr char LOG_LEVELS[] = "TDIWEF"; 20 | 21 | class LogMessage : public std::basic_ostringstream { 22 | public: 23 | LogMessage(const char* file, int line, LogLevel level); 24 | virtual ~LogMessage(); 25 | 26 | private: 27 | const char* file_; 28 | int line_; 29 | LogLevel level_; 30 | }; 31 | 32 | class LogMessageFatal : public LogMessage { 33 | public: 34 | LogMessageFatal(const char* file, int line); 35 | ~LogMessageFatal(); 36 | }; 37 | 38 | #define CHECK(x) \ 39 | if (!(x)) gnnflow::LogMessageFatal(__FILE__, __LINE__) << "Check " \ 40 | "failed: " #x 41 | 42 | #define CHECK_EQ(x, y) CHECK((x) == (y)) 43 | #define CHECK_NE(x, y) CHECK((x) != (y)) 44 | #define CHECK_GT(x, y) CHECK((x) > (y)) 45 | #define CHECK_GE(x, y) CHECK((x) >= (y)) 46 | #define CHECK_LT(x, y) CHECK((x) < (y)) 47 | #define CHECK_LE(x, y) CHECK((x) <= (y)) 48 | #define CHECK_NOTNULL(x) CHECK((x) != nullptr) 49 | 50 | #define LOG(level) \ 51 | gnnflow::LogMessage(__FILE__, __LINE__, gnnflow::LogLevel::level) 52 | } // namespace gnnflow 53 | 54 | #endif // GNNFLOW_LOGGING_H_ 55 | -------------------------------------------------------------------------------- /gnnflow/csrc/resource_holder.h: -------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_RESOURCE_HOLDER_H 2 | #define GNNFLOW_RESOURCE_HOLDER_H 3 | 4 | #include 5 | #include 6 | 7 | #include "logging.h" 8 | #include "utils.h" 9 | 10 | namespace gnnflow { 11 | 12 | template 13 | class ResourceHolder { 14 | public: 15 | ResourceHolder() = default; 16 | ~ResourceHolder() = default; 17 | 18 | // convert to T 19 | operator T() const& { return resource_; } 20 | 21 | protected: 22 | T resource_; 23 | }; 24 | 25 | template <> 26 | class ResourceHolder { 27 | public: 28 | ResourceHolder() { 29 | CUDA_CALL(cudaStreamCreate(&resource_)); 30 | CUDA_CALL( 31 | cudaStreamCreateWithPriority(&resource_, cudaStreamNonBlocking, 0)) 32 | } 33 | 34 | ~ResourceHolder() { cudaStreamDestroy(resource_); } 35 | 36 | // convert to cudaStream_t 37 | operator 
cudaStream_t() const& { return resource_; } 38 | 39 | protected: 40 | cudaStream_t resource_; 41 | }; 42 | typedef ResourceHolder StreamHolder; 43 | 44 | template <> 45 | class ResourceHolder { 46 | public: 47 | ResourceHolder() = default; 48 | ResourceHolder(std::size_t size) : resource_(nullptr), size_(size) {} 49 | virtual ~ResourceHolder() = default; 50 | 51 | operator char*() const& { return resource_; } 52 | char* operator+(std::size_t offset) const { return resource_ + offset; } 53 | 54 | std::size_t size() const { return size_; } 55 | 56 | protected: 57 | char* resource_; 58 | std::size_t size_; 59 | }; 60 | typedef ResourceHolder Buffer; 61 | 62 | class GPUResourceHolder : public ResourceHolder { 63 | public: 64 | GPUResourceHolder() = default; 65 | GPUResourceHolder(std::size_t size) : ResourceHolder(size) { 66 | CUDA_CALL(cudaMalloc(&resource_, size)); 67 | } 68 | ~GPUResourceHolder() { cudaFree(resource_); } 69 | }; 70 | typedef GPUResourceHolder GPUBuffer; 71 | 72 | class PinMemoryResourceHolder : public ResourceHolder { 73 | public: 74 | PinMemoryResourceHolder() = default; 75 | PinMemoryResourceHolder(std::size_t size) : ResourceHolder(size) { 76 | CUDA_CALL(cudaMallocHost(&resource_, size)); 77 | } 78 | ~PinMemoryResourceHolder() { cudaFreeHost(resource_); } 79 | }; 80 | typedef PinMemoryResourceHolder PinMemoryBuffer; 81 | 82 | class CuRandStateResourceHolder : public ResourceHolder { 83 | public: 84 | CuRandStateResourceHolder() = default; 85 | CuRandStateResourceHolder(std::size_t num_elements, uint64_t seed) { 86 | CUDA_CALL( 87 | cudaMalloc((void**)&resource_, num_elements * sizeof(curandState_t))); 88 | uint32_t num_threads_per_block = 256; 89 | uint32_t num_blocks = 90 | (num_elements + num_threads_per_block - 1) / num_threads_per_block; 91 | 92 | InitCuRandStates<<>>(resource_, 93 | num_elements, seed); 94 | } 95 | 96 | ~CuRandStateResourceHolder() { cudaFree(resource_); } 97 | 98 | operator curandState_t*() const& { return resource_; } 99 | }; 100 | typedef CuRandStateResourceHolder CuRandStateHolder; 101 | 102 | } // namespace gnnflow 103 | 104 | #endif // GNNFLOW_RESOURCE_HOLDER_H 105 | -------------------------------------------------------------------------------- /gnnflow/csrc/sampling_kernels.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "common.h" 5 | #include "sampling_kernels.h" 6 | #include "utils.h" 7 | 8 | namespace gnnflow { 9 | 10 | __global__ void SampleLayerRecentKernel( 11 | const DoublyLinkedList* node_table, std::size_t num_nodes, bool prop_time, 12 | const NIDType* root_nodes, const TimestampType* root_timestamps, 13 | uint32_t snapshot_idx, uint32_t num_snapshots, 14 | TimestampType snapshot_time_window, uint32_t num_root_nodes, 15 | uint32_t fanout, NIDType* src_nodes, EIDType* eids, 16 | TimestampType* timestamps, TimestampType* delta_timestamps, 17 | uint32_t* num_sampled) { 18 | uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x; 19 | if (tid >= num_root_nodes) { 20 | return; 21 | } 22 | 23 | NIDType nid = root_nodes[tid]; 24 | TimestampType root_timestamp = root_timestamps[tid]; 25 | TimestampType start_timestamp, end_timestamp; 26 | if (num_snapshots == 1) { 27 | if (abs(snapshot_time_window) < 1e-6) { 28 | start_timestamp = 0; 29 | } else { 30 | start_timestamp = root_timestamp - snapshot_time_window; 31 | } 32 | end_timestamp = root_timestamp; 33 | } else { 34 | end_timestamp = root_timestamp - 35 | (num_snapshots - snapshot_idx - 1) * 
snapshot_time_window; 36 | start_timestamp = end_timestamp - snapshot_time_window; 37 | } 38 | 39 | // NB: the tail block is the newest block 40 | auto curr = node_table[nid].tail; 41 | uint32_t offset = tid * fanout; 42 | int start_idx, end_idx; 43 | uint32_t sampled = 0; 44 | while (curr != nullptr && curr->capacity > 0 && sampled < fanout) { 45 | if (end_timestamp < curr->start_timestamp) { 46 | // search in the previous block 47 | curr = curr->prev; 48 | continue; 49 | } 50 | 51 | if (start_timestamp > curr->end_timestamp) { 52 | // no need to search in the previous block 53 | break; 54 | } 55 | 56 | // search in the current block 57 | if (start_timestamp >= curr->start_timestamp && 58 | end_timestamp <= curr->end_timestamp) { 59 | // all edges in the current block 60 | LowerBound(curr->timestamps, curr->size, start_timestamp, &start_idx); 61 | LowerBound(curr->timestamps, curr->size, end_timestamp, &end_idx); 62 | } else if (start_timestamp < curr->start_timestamp && 63 | end_timestamp <= curr->end_timestamp) { 64 | // only the edges before end_timestamp are in the current block 65 | start_idx = 0; 66 | LowerBound(curr->timestamps, curr->size, end_timestamp, &end_idx); 67 | } else if (start_timestamp > curr->start_timestamp && 68 | end_timestamp > curr->end_timestamp) { 69 | // only the edges after start_timestamp are in the current block 70 | LowerBound(curr->timestamps, curr->size, start_timestamp, &start_idx); 71 | end_idx = curr->size; 72 | } else { 73 | // the whole block is in the range 74 | start_idx = 0; 75 | end_idx = curr->size; 76 | } 77 | 78 | // copy the edges to the output 79 | for (int i = end_idx - 1; sampled < fanout && i >= start_idx; --i) { 80 | src_nodes[offset + sampled] = curr->dst_nodes[i]; 81 | eids[offset + sampled] = curr->eids[i]; 82 | timestamps[offset + sampled] = 83 | prop_time ? 
root_timestamp : curr->timestamps[i]; 84 | delta_timestamps[offset + sampled] = root_timestamp - curr->timestamps[i]; 85 | ++sampled; 86 | } 87 | 88 | curr = curr->prev; 89 | } 90 | 91 | num_sampled[tid] = sampled; 92 | 93 | while (sampled < fanout) { 94 | src_nodes[offset + sampled] = kInvalidNID; 95 | ++sampled; 96 | } 97 | } 98 | 99 | __global__ void SampleLayerUniformKernel( 100 | const DoublyLinkedList* node_table, std::size_t num_nodes, bool prop_time, 101 | curandState_t* rand_states, uint64_t seed, uint32_t offset_per_thread, 102 | const NIDType* root_nodes, const TimestampType* root_timestamps, 103 | uint32_t snapshot_idx, uint32_t num_snapshots, 104 | TimestampType snapshot_time_window, uint32_t num_root_nodes, 105 | uint32_t fanout, NIDType* src_nodes, EIDType* eids, 106 | TimestampType* timestamps, TimestampType* delta_timestamps, 107 | uint32_t* num_sampled) { 108 | uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x; 109 | if (tid >= num_root_nodes) { 110 | return; 111 | } 112 | 113 | extern __shared__ SamplingRange ranges[]; 114 | 115 | NIDType nid = root_nodes[tid]; 116 | TimestampType root_timestamp = root_timestamps[tid]; 117 | TimestampType start_timestamp, end_timestamp; 118 | if (num_snapshots == 1) { 119 | if (abs(snapshot_time_window) < 1e-6) { 120 | start_timestamp = 0; 121 | } else { 122 | start_timestamp = root_timestamp - snapshot_time_window; 123 | } 124 | end_timestamp = root_timestamp; 125 | } else { 126 | end_timestamp = root_timestamp - 127 | (num_snapshots - snapshot_idx - 1) * snapshot_time_window; 128 | start_timestamp = end_timestamp - snapshot_time_window; 129 | } 130 | 131 | auto& list = node_table[nid]; 132 | uint32_t num_candidates = 0; 133 | 134 | // NB: the tail block is the newest block 135 | auto curr = list.tail; 136 | int start_idx, end_idx; 137 | int curr_idx = 0; 138 | const int offset_by_thread = offset_per_thread * threadIdx.x; 139 | while (curr != nullptr && curr->capacity > 0) { 140 | if (end_timestamp < curr->start_timestamp) { 141 | // search in the prev block 142 | curr = curr->prev; 143 | curr_idx += 1; 144 | continue; 145 | } 146 | 147 | if (start_timestamp > curr->end_timestamp) { 148 | // no need to search in the prev block 149 | break; 150 | } 151 | 152 | // search in the current block 153 | if (start_timestamp >= curr->start_timestamp && 154 | end_timestamp <= curr->end_timestamp) { 155 | // all edges in the current block 156 | LowerBound(curr->timestamps, curr->size, start_timestamp, &start_idx); 157 | LowerBound(curr->timestamps, curr->size, end_timestamp, &end_idx); 158 | } else if (start_timestamp < curr->start_timestamp && 159 | end_timestamp <= curr->end_timestamp) { 160 | // only the edges before end_timestamp are in the current block 161 | start_idx = 0; 162 | LowerBound(curr->timestamps, curr->size, end_timestamp, &end_idx); 163 | } else if (start_timestamp > curr->start_timestamp && 164 | end_timestamp > curr->end_timestamp) { 165 | // only the edges after start_timestamp are in the current block 166 | LowerBound(curr->timestamps, curr->size, start_timestamp, &start_idx); 167 | end_idx = curr->size; 168 | } else { 169 | // the whole block is in the range 170 | start_idx = 0; 171 | end_idx = curr->size; 172 | } 173 | 174 | if (curr_idx < offset_per_thread) { 175 | ranges[offset_by_thread + curr_idx].start_idx = start_idx; 176 | ranges[offset_by_thread + curr_idx].end_idx = end_idx; 177 | } 178 | 179 | num_candidates += end_idx - start_idx; 180 | curr = curr->prev; 181 | curr_idx += 1; 182 | } 183 | 184 | uint32_t 
indices[kMaxFanout]; 185 | uint32_t to_sample = min(fanout, num_candidates); 186 | for (uint32_t i = 0; i < to_sample; i++) { 187 | indices[i] = curand(rand_states + tid) % num_candidates; 188 | } 189 | QuickSort(indices, 0, to_sample - 1); 190 | 191 | uint32_t sampled = 0; 192 | uint32_t offset = tid * fanout; 193 | 194 | curr = list.tail; 195 | curr_idx = 0; 196 | uint32_t cumsum = 0; 197 | while (curr != nullptr && curr->capacity > 0) { 198 | if (end_timestamp < curr->start_timestamp) { 199 | // search in the prev block 200 | curr = curr->prev; 201 | curr_idx += 1; 202 | continue; 203 | } 204 | 205 | if (start_timestamp > curr->end_timestamp) { 206 | // no need to search in the prev block 207 | break; 208 | } 209 | 210 | if (curr_idx < offset_per_thread) { 211 | start_idx = ranges[offset_by_thread + curr_idx].start_idx; 212 | end_idx = ranges[offset_by_thread + curr_idx].end_idx; 213 | } else { 214 | // search in the current block 215 | if (start_timestamp >= curr->start_timestamp && 216 | end_timestamp <= curr->end_timestamp) { 217 | // all edges in the current block 218 | LowerBound(curr->timestamps, curr->size, start_timestamp, &start_idx); 219 | LowerBound(curr->timestamps, curr->size, end_timestamp, &end_idx); 220 | } else if (start_timestamp < curr->start_timestamp && 221 | end_timestamp <= curr->end_timestamp) { 222 | // only the edges before end_timestamp are in the current block 223 | start_idx = 0; 224 | LowerBound(curr->timestamps, curr->size, end_timestamp, &end_idx); 225 | } else if (start_timestamp > curr->start_timestamp && 226 | end_timestamp > curr->end_timestamp) { 227 | // only the edges after start_timestamp are in the current block 228 | LowerBound(curr->timestamps, curr->size, start_timestamp, &start_idx); 229 | end_idx = curr->size; 230 | } else { 231 | // the whole block is in the range 232 | start_idx = 0; 233 | end_idx = curr->size; 234 | } 235 | } 236 | 237 | auto idx = indices[sampled] - cumsum; 238 | while (sampled < to_sample && idx < end_idx - start_idx) { 239 | // start from end_idx (newer edges) 240 | src_nodes[offset + sampled] = curr->dst_nodes[end_idx - idx - 1]; 241 | eids[offset + sampled] = curr->eids[end_idx - idx - 1]; 242 | timestamps[offset + sampled] = 243 | prop_time ? 
root_timestamp : curr->timestamps[end_idx - idx - 1]; 244 | delta_timestamps[offset + sampled] = 245 | root_timestamp - curr->timestamps[end_idx - idx - 1]; 246 | idx = indices[++sampled] - cumsum; 247 | } 248 | 249 | if (sampled >= to_sample) { 250 | break; 251 | } 252 | 253 | cumsum += end_idx - start_idx; 254 | curr = curr->prev; 255 | curr_idx += 1; 256 | } 257 | 258 | num_sampled[tid] = sampled; 259 | 260 | while (sampled < fanout) { 261 | src_nodes[offset + sampled] = kInvalidNID; 262 | ++sampled; 263 | } 264 | } 265 | 266 | } // namespace gnnflow 267 | -------------------------------------------------------------------------------- /gnnflow/csrc/sampling_kernels.h: -------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_SAMPLING_KERNELS_H_ 2 | #define GNNFLOW_SAMPLING_KERNELS_H_ 3 | 4 | #include 5 | 6 | #include "doubly_linked_list.h" 7 | 8 | namespace gnnflow { 9 | 10 | __global__ void SampleLayerRecentKernel( 11 | const DoublyLinkedList* node_table, std::size_t num_nodes, bool prop_time, 12 | const NIDType* root_nodes, const TimestampType* root_timestamps, 13 | uint32_t snapshot_idx, uint32_t num_snapshots, 14 | TimestampType snapshot_time_window, uint32_t num_root_nodes, 15 | uint32_t fanout, NIDType* src_nodes, EIDType* eid, 16 | TimestampType* timestamps, TimestampType* delta_timestamps, 17 | uint32_t* num_sampled); 18 | 19 | __global__ void SampleLayerUniformKernel( 20 | const DoublyLinkedList* node_table, std::size_t num_nodes, bool prop_time, 21 | curandState_t* rand_states, uint64_t seed, uint32_t offset_per_thread, 22 | const NIDType* root_nodes, const TimestampType* root_timestamps, 23 | uint32_t snapshot_idx, uint32_t num_snapshots, 24 | TimestampType snapshot_time_window, uint32_t num_root_nodes, 25 | uint32_t fanout, NIDType* src_nodes, EIDType* eids, 26 | TimestampType* timestamps, TimestampType* delta_timestamps, 27 | uint32_t* num_sampled); 28 | 29 | } // namespace gnnflow 30 | 31 | #endif // GNNFLOW_SAMPLING_KERNELS_H_ 32 | -------------------------------------------------------------------------------- /gnnflow/csrc/temporal_block_allocator.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "logging.h" 11 | #include "temporal_block_allocator.h" 12 | #include "utils.h" 13 | 14 | namespace gnnflow { 15 | 16 | TemporalBlockAllocator::TemporalBlockAllocator( 17 | std::size_t initial_pool_size, std::size_t maximum_pool_size, 18 | std::size_t minium_block_size, MemoryResourceType mem_resource_type, 19 | int device) 20 | : mem_resource_type_(mem_resource_type), 21 | minium_block_size_(minium_block_size), 22 | device_(device), 23 | allocated_(0) { 24 | LOG(DEBUG) << "set device to " << device; 25 | CUDA_CALL(cudaSetDevice(device)); 26 | // create a memory pool 27 | switch (mem_resource_type) { 28 | case MemoryResourceType::kMemoryResourceTypeCUDA: { 29 | auto mem_res = new rmm::mr::cuda_memory_resource(); 30 | mem_resources_.push(mem_res); 31 | auto pool_res = 32 | new rmm::mr::pool_memory_resource( 33 | mem_res, initial_pool_size, maximum_pool_size); 34 | mem_resources_.push(pool_res); 35 | break; 36 | } 37 | case MemoryResourceType::kMemoryResourceTypeUnified: { 38 | auto mem_res = new rmm::mr::managed_memory_resource(); 39 | mem_resources_.push(mem_res); 40 | auto pool_res = 41 | new rmm::mr::pool_memory_resource( 42 | mem_res, initial_pool_size, maximum_pool_size); 43 | 
mem_resources_.push(pool_res); 44 | break; 45 | } 46 | case MemoryResourceType::kMemoryResourceTypePinned: { 47 | auto mem_res = new rmm::mr::pinned_memory_resource(); 48 | mem_resources_.push(mem_res); 49 | auto pool_res = 50 | new rmm::mr::pool_memory_resource( 51 | mem_res, initial_pool_size, maximum_pool_size); 52 | mem_resources_.push(pool_res); 53 | break; 54 | } 55 | case MemoryResourceType::kMemoryResourceTypeShared: { 56 | // NB: device ID is equal to the local rank 57 | auto mem_res = new rmm::mr::shared_memory_resource(device); 58 | mem_resources_.push(mem_res); 59 | auto pool_res = 60 | new rmm::mr::pool_memory_resource( 61 | mem_res, initial_pool_size, maximum_pool_size); 62 | mem_resources_.push(pool_res); 63 | break; 64 | } 65 | } 66 | } 67 | 68 | TemporalBlockAllocator::~TemporalBlockAllocator() { 69 | for (auto &block : blocks_) { 70 | DeallocateInternal(block); 71 | delete block; 72 | } 73 | 74 | // release the memory pool 75 | while (!mem_resources_.empty()) { 76 | delete mem_resources_.top(); 77 | mem_resources_.pop(); 78 | } 79 | 80 | blocks_.clear(); 81 | } 82 | 83 | std::size_t TemporalBlockAllocator::AlignUp(std::size_t size) { 84 | if (size < minium_block_size_) { 85 | return minium_block_size_; 86 | } 87 | return size; 88 | } 89 | 90 | TemporalBlock *TemporalBlockAllocator::Allocate(std::size_t size) { 91 | auto block = new TemporalBlock(); 92 | 93 | try { 94 | AllocateInternal(block, size); 95 | } catch (rmm::bad_alloc &) { 96 | // failed to allocate memory 97 | DeallocateInternal(block); 98 | 99 | LOG(FATAL) << "Failed to allocate memory for temporal block of size " 100 | << size; 101 | } 102 | 103 | { 104 | std::lock_guard lock(mutex_); 105 | blocks_.push_back(block); 106 | } 107 | return block; 108 | } 109 | 110 | void TemporalBlockAllocator::Deallocate(TemporalBlock *block) { 111 | CHECK_NOTNULL(block); 112 | DeallocateInternal(block); 113 | 114 | { 115 | std::lock_guard lock(mutex_); 116 | blocks_.erase(std::remove(blocks_.begin(), blocks_.end(), block)); 117 | } 118 | 119 | delete block; 120 | } 121 | 122 | void TemporalBlockAllocator::Reallocate(TemporalBlock *block, std::size_t size, 123 | cudaStream_t stream) { 124 | CHECK_NOTNULL(block); 125 | 126 | TemporalBlock tmp; 127 | AllocateInternal(&tmp, size); 128 | CopyTemporalBlock(block, &tmp, device_, stream); 129 | DeallocateInternal(block); 130 | 131 | *block = tmp; 132 | } 133 | 134 | void TemporalBlockAllocator::AllocateInternal( 135 | TemporalBlock *block, std::size_t size) noexcept(false) { 136 | std::size_t capacity = AlignUp(size); 137 | 138 | block->size = 0; // empty block 139 | block->capacity = capacity; 140 | block->start_timestamp = std::numeric_limits::max(); 141 | block->end_timestamp = 0; 142 | block->prev = nullptr; 143 | block->next = nullptr; 144 | 145 | // allocate memory for the block 146 | // NB: rmm is thread-safe 147 | auto mr = mem_resources_.top(); 148 | block->dst_nodes = 149 | static_cast(mr->allocate(capacity * sizeof(NIDType))); 150 | block->timestamps = static_cast( 151 | mr->allocate(capacity * sizeof(TimestampType))); 152 | block->eids = 153 | static_cast(mr->allocate(capacity * sizeof(EIDType))); 154 | 155 | allocated_ += 156 | capacity * (sizeof(NIDType) + sizeof(TimestampType) + sizeof(EIDType)); 157 | } 158 | 159 | void TemporalBlockAllocator::DeallocateInternal(TemporalBlock *block) { 160 | auto mr = mem_resources_.top(); 161 | if (block->dst_nodes != nullptr) { 162 | mr->deallocate(block->dst_nodes, block->capacity * sizeof(NIDType)); 163 | block->dst_nodes = 
nullptr; 164 | allocated_ -= block->capacity * sizeof(NIDType); 165 | } 166 | if (block->timestamps != nullptr) { 167 | mr->deallocate(block->timestamps, block->capacity * sizeof(TimestampType)); 168 | block->timestamps = nullptr; 169 | allocated_ -= block->capacity * sizeof(TimestampType); 170 | } 171 | if (block->eids != nullptr) { 172 | mr->deallocate(block->eids, block->capacity * sizeof(EIDType)); 173 | block->eids = nullptr; 174 | allocated_ -= block->capacity * sizeof(EIDType); 175 | } 176 | 177 | block->size = 0; 178 | block->capacity = 0; 179 | // NB: we don't reset the timestamps and prev/next 180 | } 181 | 182 | void TemporalBlockAllocator::SaveToFile(TemporalBlock *block, 183 | NIDType src_node) { 184 | if (mem_resource_type_ != MemoryResourceType::kMemoryResourceTypePinned && 185 | mem_resource_type_ != MemoryResourceType::kMemoryResourceTypeShared) { 186 | LOG(FATAL) << "Only pinned and shared memory resources are supported"; 187 | } 188 | 189 | // NB: only first rank saves the temporal block 190 | if (device_ == 0) { 191 | std::string file_name = "temporal_block_" + std::to_string(src_node) + "-" + 192 | std::to_string(num_saved_blocks_[src_node]) + 193 | ".bin"; 194 | std::ofstream file(file_name, std::ios::out | std::ios::binary); 195 | file.write(reinterpret_cast(&block->size), sizeof(block->size)); 196 | file.write(reinterpret_cast(&block->capacity), 197 | sizeof(block->capacity)); 198 | file.write(reinterpret_cast(&block->start_timestamp), 199 | sizeof(block->start_timestamp)); 200 | file.write(reinterpret_cast(&block->end_timestamp), 201 | sizeof(block->end_timestamp)); 202 | file.write(reinterpret_cast(block->dst_nodes), 203 | sizeof(NIDType) * block->size); 204 | file.write(reinterpret_cast(block->timestamps), 205 | sizeof(TimestampType) * block->size); 206 | file.write(reinterpret_cast(block->eids), 207 | sizeof(EIDType) * block->size); 208 | file.write(reinterpret_cast(&block->prev), sizeof(block->prev)); 209 | file.write(reinterpret_cast(&block->next), sizeof(block->next)); 210 | file.close(); 211 | 212 | saved_blocks_[block] = file_name; 213 | num_saved_blocks_[src_node]++; 214 | 215 | LOG(INFO) << "Temporal block saved to " << file_name; 216 | } 217 | 218 | // NB: all ranks need to deallocate the temporal block (but only first rank 219 | // release the memory) 220 | DeallocateInternal(block); 221 | } 222 | 223 | void TemporalBlockAllocator::ReadFromFile(TemporalBlock *block, 224 | NIDType src_node) { 225 | if (mem_resource_type_ != MemoryResourceType::kMemoryResourceTypePinned) { 226 | LOG(FATAL) << "Only pinned memory resources are supported"; 227 | // NB: shared memory resources are not supported because we need to know the 228 | // offset of the temporal block in the shared memory 229 | } 230 | CHECK_EQ(device_, 0) << "Only first rank can read temporal blocks from file"; 231 | 232 | std::string file_name = saved_blocks_[block]; 233 | std::ifstream file(file_name, std::ios::in | std::ios::binary); 234 | file.read(reinterpret_cast(&block->size), sizeof(block->size)); 235 | file.read(reinterpret_cast(&block->capacity), 236 | sizeof(block->capacity)); 237 | AllocateInternal(block, block->capacity); 238 | file.read(reinterpret_cast(&block->start_timestamp), 239 | sizeof(block->start_timestamp)); 240 | file.read(reinterpret_cast(&block->end_timestamp), 241 | sizeof(block->end_timestamp)); 242 | file.read(reinterpret_cast(block->dst_nodes), 243 | sizeof(NIDType) * block->size); 244 | file.read(reinterpret_cast(block->timestamps), 245 | sizeof(TimestampType) * 
block->size); 246 | file.read(reinterpret_cast(block->eids), 247 | sizeof(EIDType) * block->size); 248 | file.read(reinterpret_cast(&block->prev), sizeof(block->prev)); 249 | file.read(reinterpret_cast(&block->next), sizeof(block->next)); 250 | file.close(); 251 | 252 | saved_blocks_.erase(block); 253 | num_saved_blocks_[src_node]--; 254 | 255 | LOG(INFO) << "Temporal block read from " << file_name; 256 | } 257 | } // namespace gnnflow 258 | -------------------------------------------------------------------------------- /gnnflow/csrc/temporal_block_allocator.h: -------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_TEMPORAL_BLOCK_ALLOCATOR_H_ 2 | #define GNNFLOW_TEMPORAL_BLOCK_ALLOCATOR_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "common.h" 10 | 11 | namespace gnnflow { 12 | /** 13 | * @brief This class implements a thread-safe memory resource that allocates 14 | * temporal blocks. 15 | * 16 | * The allocated blocks are in the host memory. But the edges are stored in the 17 | * device memory or managed memory or pinned memory, depending on the memory 18 | * resource type. 19 | */ 20 | class TemporalBlockAllocator { 21 | public: 22 | /** 23 | * @brief Constructor. 24 | * 25 | * It creates a memory pool. 26 | * 27 | * @param initial_pool_size The initial size of the memory pool. 28 | * @param maxmium_pool_size The maximum size of the memory pool 29 | * @param minimum_block_size The minimum size of the temporal block. 30 | * @param MemoryResourceType The type of memory resource. 31 | * @param device The device id. 32 | */ 33 | TemporalBlockAllocator(std::size_t initial_pool_size, 34 | std::size_t maximum_pool_size, 35 | std::size_t minimum_block_size, 36 | MemoryResourceType mem_resource_type, int device); 37 | 38 | /** 39 | * @brief Destructor. 40 | * 41 | * It frees all the temporal blocks. 42 | */ 43 | ~TemporalBlockAllocator(); 44 | 45 | /** 46 | * @brief Allocate a temporal block. 47 | * 48 | * NB: the block itself is in the host memory. 49 | * 50 | * @param size The size of the temporal block. 51 | * 52 | * @return A host pointer to the temporal block. 53 | */ 54 | TemporalBlock* Allocate(std::size_t size); 55 | 56 | /** 57 | * @brief Deallocate a temporal block. 58 | * 59 | * @param block The temporal block to deallocate. It must be in the host 60 | * memory. 61 | */ 62 | void Deallocate(TemporalBlock* block); 63 | 64 | /** 65 | * @brief Reallocate a temporal block. 66 | * 67 | * NB: We only change the content of the temporal block. The host pointer to 68 | * the temporal block is not changed. 69 | * 70 | * @param block The temporal block to reallocate. It must be in the host 71 | * memory. 72 | * @param size The new size of the temporal block. 73 | * @param stream The stream to use. If nullptr, the default stream is used. 
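 *
 * A minimal usage sketch (allocator, block and num_incoming are assumed to
 * exist; the grow-on-demand pattern is illustrative):
 *
 *   if (block->size + num_incoming > block->capacity) {
 *     allocator.Reallocate(block, block->size + num_incoming, stream);
 *   }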
74 | */ 75 | void Reallocate(TemporalBlock* block, std::size_t size, 76 | cudaStream_t stream = nullptr); 77 | 78 | void SaveToFile(TemporalBlock* block, NIDType src_node); 79 | void ReadFromFile(TemporalBlock* block, NIDType src_node); 80 | 81 | std::size_t get_total_memory_usage() const { return allocated_; } 82 | 83 | private: 84 | std::size_t AlignUp(std::size_t size); 85 | 86 | void AllocateInternal(TemporalBlock* block, std::size_t size) noexcept(false); 87 | 88 | void DeallocateInternal(TemporalBlock* block); 89 | 90 | std::vector blocks_; 91 | 92 | std::size_t minium_block_size_; 93 | MemoryResourceType mem_resource_type_; 94 | std::stack mem_resources_; 95 | 96 | std::unordered_map saved_blocks_; 97 | std::unordered_map num_saved_blocks_; 98 | 99 | std::mutex mutex_; 100 | 101 | const int device_; 102 | 103 | std::size_t allocated_; 104 | }; 105 | 106 | } // namespace gnnflow 107 | 108 | #endif // GNNFLOW_TEMPORAL_BLOCK_ALLOCATOR_H_ 109 | -------------------------------------------------------------------------------- /gnnflow/csrc/temporal_sampler.h: -------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_TEMPORAL_SAMPLER_H_ 2 | #define GNNFLOW_TEMPORAL_SAMPLER_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "common.h" 13 | #include "dynamic_graph.h" 14 | #include "resource_holder.h" 15 | 16 | namespace gnnflow { 17 | 18 | class TemporalSampler { 19 | public: 20 | TemporalSampler(const DynamicGraph& graph, 21 | const std::vector& fanouts, 22 | SamplingPolicy sample_policy, uint32_t num_snapshots = 1, 23 | float snapshot_time_window = 0.0f, bool prop_time = false, 24 | uint64_t seed = 1234); 25 | ~TemporalSampler() = default; 26 | 27 | std::vector> Sample( 28 | const std::vector& dst_nodes, 29 | const std::vector& dst_timestamps); 30 | 31 | // NB: this function should handle input with dynamic length (i.e., 32 | // `dst_nodes` and `dst_timestamps` can have dynamic lengths every time). Make 33 | // sure to re-allocate the buffer if the input size is larger than the current 34 | // buffer size. 
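  // A sketch of the intended growth check (the real logic lives in
  // InitBufferIfNeeded declared below; the exact condition and the local
  // name max_sampled_nodes are assumptions):
  //
  //   if (num_root_nodes > maximum_num_root_nodes_ ||
  //       max_sampled_nodes > maximum_sampled_nodes_) {
  //     // re-create cpu_buffer_, gpu_input_buffer_ and gpu_output_buffer_
  //   }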
35 | SamplingResult SampleLayer(const std::vector& dst_nodes, 36 | const std::vector& dst_timestamps, 37 | uint32_t layer, uint32_t snapshot); 38 | 39 | private: 40 | constexpr static std::size_t kPerNodeInputBufferSize = 41 | sizeof(NIDType) + sizeof(TimestampType); 42 | 43 | constexpr static std::size_t kPerNodeOutputBufferSize = 44 | sizeof(NIDType) + sizeof(TimestampType) + sizeof(EIDType) + 45 | sizeof(TimestampType) + sizeof(uint32_t); 46 | 47 | typedef std::tuple InputBufferTuple; 48 | InputBufferTuple GetInputBufferTuple(const Buffer& buffer, 49 | std::size_t num_root_nodes) const; 50 | 51 | typedef std::tuple 53 | OutputBufferTuple; 54 | OutputBufferTuple GetOutputBufferTuple( 55 | const Buffer& buffer, std::size_t num_root_nodes, 56 | std::size_t maximum_sampled_nodes) const; 57 | 58 | void InitBufferIfNeeded(std::size_t num_root_nodes, 59 | std::size_t maximum_sampled_nodes); 60 | 61 | private: 62 | const DynamicGraph& graph_; // sampling does not modify the graph 63 | std::vector fanouts_; 64 | SamplingPolicy sampling_policy_; 65 | uint32_t num_snapshots_; 66 | float snapshot_time_window_; 67 | bool prop_time_; 68 | uint32_t num_layers_; 69 | uint64_t seed_; 70 | std::size_t shared_memory_size_; 71 | int device_; 72 | 73 | std::unique_ptr stream_holders_; 74 | std::unique_ptr cpu_buffer_; 75 | std::unique_ptr gpu_input_buffer_; 76 | std::unique_ptr gpu_output_buffer_; 77 | std::unique_ptr rand_states_; 78 | 79 | std::size_t maximum_num_root_nodes_; 80 | std::size_t maximum_sampled_nodes_; 81 | }; 82 | 83 | } // namespace gnnflow 84 | 85 | #endif // GNNFLOW_TEMPORAL_SAMPLER_H_ 86 | -------------------------------------------------------------------------------- /gnnflow/csrc/utils.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "common.h" 5 | #include "logging.h" 6 | #include "utils.h" 7 | 8 | namespace gnnflow { 9 | void CopyTemporalBlock(TemporalBlock* src, TemporalBlock* dst, int device, 10 | cudaStream_t stream) { 11 | CHECK_NOTNULL(src); 12 | CHECK_NOTNULL(dst); 13 | CHECK_GE(dst->capacity, src->capacity); 14 | 15 | if (device == 0) { 16 | CUDA_CALL(cudaMemcpyAsync(dst->dst_nodes, src->dst_nodes, 17 | src->size * sizeof(NIDType), cudaMemcpyDefault, 18 | stream)); 19 | 20 | CUDA_CALL(cudaMemcpyAsync(dst->timestamps, src->timestamps, 21 | src->size * sizeof(TimestampType), 22 | cudaMemcpyDefault, stream)); 23 | 24 | CUDA_CALL(cudaMemcpyAsync(dst->eids, src->eids, src->size * sizeof(EIDType), 25 | cudaMemcpyDefault, stream)); 26 | } 27 | dst->size = src->size; 28 | dst->start_timestamp = src->start_timestamp; 29 | dst->end_timestamp = src->end_timestamp; 30 | dst->next = src->next; 31 | } 32 | 33 | void CopyEdgesToBlock(TemporalBlock* block, 34 | const std::vector& dst_nodes, 35 | const std::vector& timestamps, 36 | const std::vector& eids, std::size_t start_idx, 37 | std::size_t num_edges, int device, cudaStream_t stream) { 38 | CHECK_NOTNULL(block); 39 | CHECK_EQ(dst_nodes.size(), timestamps.size()); 40 | CHECK_EQ(eids.size(), timestamps.size()); 41 | CHECK_LE(block->size + num_edges, block->capacity); 42 | // NB: we assume that the incoming edges are newer than the existing ones. 
43 | CHECK_LE(block->end_timestamp, timestamps[start_idx + num_edges - 1]); 44 | 45 | if (device == 0) { 46 | CUDA_CALL(cudaMemcpyAsync( 47 | block->dst_nodes + block->size, &dst_nodes[start_idx], 48 | sizeof(NIDType) * num_edges, cudaMemcpyDefault, stream)); 49 | 50 | CUDA_CALL(cudaMemcpyAsync( 51 | block->timestamps + block->size, ×tamps[start_idx], 52 | sizeof(TimestampType) * num_edges, cudaMemcpyDefault, stream)); 53 | 54 | CUDA_CALL(cudaMemcpyAsync(block->eids + block->size, &eids[start_idx], 55 | sizeof(EIDType) * num_edges, cudaMemcpyDefault, 56 | stream)); 57 | } 58 | block->size += num_edges; 59 | 60 | block->start_timestamp = 61 | std::min(block->start_timestamp, timestamps[start_idx]); 62 | block->end_timestamp = timestamps[start_idx + num_edges - 1]; 63 | } 64 | 65 | std::size_t GetSharedMemoryMaxSize() { 66 | std::size_t max_size = 0; 67 | cudaDeviceProp prop; 68 | cudaGetDeviceProperties(&prop, 0); 69 | max_size = prop.sharedMemPerBlock; 70 | return max_size; 71 | } 72 | 73 | void Copy(void* dst, const void* src, std::size_t size) { 74 | auto in = (long long*)src; 75 | auto out = (long long*)dst; 76 | constexpr std::size_t kUnitSize = sizeof(long long); 77 | #pragma omp parallel for simd num_threads(4) schedule(static) 78 | for (size_t i = 0; i < size / kUnitSize; ++i) { 79 | out[i] = in[i]; 80 | } 81 | 82 | if (size % kUnitSize) { 83 | std::memcpy(out + size / kUnitSize, in + size / kUnitSize, 84 | size % kUnitSize); 85 | } 86 | } 87 | 88 | __global__ void InitCuRandStates(curandState_t* states, 89 | std::size_t num_elements, uint64_t seed) { 90 | int tid = threadIdx.x + blockIdx.x * blockDim.x; 91 | if (tid < num_elements) { 92 | curand_init(seed, tid, 0, &states[tid]); 93 | } 94 | } 95 | 96 | __host__ __device__ void LowerBound(TimestampType* timestamps, int num_edges, 97 | TimestampType timestamp, int* idx) { 98 | int left = 0; 99 | int right = num_edges; 100 | while (left < right) { 101 | int mid = (left + right) / 2; 102 | if (timestamps[mid] < timestamp) { 103 | left = mid + 1; 104 | } else { 105 | right = mid; 106 | } 107 | } 108 | *idx = left; 109 | } 110 | 111 | template 112 | __host__ __device__ void inline swap(T& a, T& b) { 113 | T c(a); 114 | a = b; 115 | b = c; 116 | } 117 | 118 | __host__ __device__ void QuickSort(uint32_t* indices, int lo, int hi) { 119 | if (lo >= hi || lo < 0 || hi < 0) return; 120 | 121 | uint32_t pivot = indices[hi]; 122 | int i = lo - 1; 123 | for (int j = lo; j < hi; ++j) { 124 | if (indices[j] < pivot) { 125 | swap(indices[++i], indices[j]); 126 | } 127 | } 128 | swap(indices[++i], indices[hi]); 129 | 130 | QuickSort(indices, lo, i - 1); 131 | QuickSort(indices, i + 1, hi); 132 | } 133 | 134 | } // namespace gnnflow 135 | -------------------------------------------------------------------------------- /gnnflow/csrc/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef GNNFLOW_UTILS_H_ 2 | #define GNNFLOW_UTILS_H_ 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "common.h" 12 | 13 | namespace gnnflow { 14 | 15 | template 16 | std::vector stable_sort_indices(const std::vector& v) { 17 | // initialize original index locations 18 | std::vector idx(v.size()); 19 | std::iota(idx.begin(), idx.end(), 0); 20 | 21 | // sort indexes based on comparing values in v 22 | std::stable_sort( 23 | idx.begin(), idx.end(), 24 | [&v](std::size_t i1, std::size_t i2) { return v[i1] < v[i2]; }); 25 | 26 | return idx; 27 | } 28 | 29 | template 30 | std::vector 
sort_vector(const std::vector& v, 31 | const std::vector& idx) { 32 | std::vector sorted_v; 33 | sorted_v.reserve(v.size()); 34 | for (auto i : idx) { 35 | sorted_v.emplace_back(v[i]); 36 | } 37 | return sorted_v; 38 | } 39 | 40 | /** 41 | * @brief Copy a temporal block on the GPU to another block. 42 | * 43 | * The destination block should have a size greater than or equal to the 44 | * source block. It assumes that the source block is on the GPU. But the 45 | * destination block can be on the CPU or on the GPU. 46 | * 47 | * @param dst The destination temporal block. 48 | * @param src The source temporal block. 49 | * @param device The device id. 50 | * @param stream The CUDA stream. 51 | */ 52 | void CopyTemporalBlock(TemporalBlock* src, TemporalBlock* dst, int device, 53 | cudaStream_t stream = nullptr); 54 | 55 | /** 56 | * @brief Copy edges on the CPU to the block on the GPU. 57 | * 58 | * The destination block should have a size greater than or equal to the 59 | * incoming edges. 60 | * 61 | * @param block The destination temporal block. 62 | * @param dst_nodes The destination nodes. 63 | * @param timestamps The timestamps of the incoming edges. 64 | * @param eids The ids of the incoming edges. 65 | * @param start_idx The start index of the incoming edges. 66 | * @param num_edges The number of incoming edges. 67 | * @param device The device id. 68 | * @param stream The CUDA stream. 69 | */ 70 | void CopyEdgesToBlock(TemporalBlock* block, 71 | const std::vector& dst_nodes, 72 | const std::vector& timestamps, 73 | const std::vector& eids, std::size_t start_idx, 74 | std::size_t num_edges, int device, 75 | cudaStream_t stream = nullptr); 76 | 77 | std::size_t GetSharedMemoryMaxSize(); 78 | 79 | void Copy(void* dst, const void* src, std::size_t size); 80 | 81 | __global__ void InitCuRandStates(curandState_t* state, std::size_t num_elements, 82 | uint64_t seed); 83 | 84 | __host__ __device__ void LowerBound(TimestampType* timestamps, int num_edges, 85 | TimestampType timestamp, int* idx); 86 | 87 | __host__ __device__ void QuickSort(uint32_t* indices, int lo, int hi); 88 | } // namespace gnnflow 89 | 90 | #endif // GNNFLOW_UTILS_H_ 91 | -------------------------------------------------------------------------------- /gnnflow/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist_context import initialize, dispatch_full_dataset 2 | -------------------------------------------------------------------------------- /gnnflow/distributed/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SamplingResultTorch: 5 | def __init__(self): 6 | self.row: torch.Tensor = None 7 | self.col: torch.Tensor = None 8 | self.num_src_nodes: int = None 9 | self.num_dst_nodes: int = None 10 | self.all_nodes: torch.Tensor = None 11 | self.all_timestamps: torch.Tensor = None 12 | self.delta_timestamps: torch.Tensor = None 13 | self.eids: torch.Tensor = None 14 | 15 | def __getstate__(self): 16 | return self.__dict__ 17 | 18 | def __setstate__(self, state): 19 | self.__dict__.update(state) 20 | 21 | 22 | # let pickle know how to serialize the SamplingResultType 23 | globals()['SamplingResultTorch'] = SamplingResultTorch 24 | -------------------------------------------------------------------------------- /gnnflow/distributed/dist_context.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | 5 | 
import psutil 6 | import torch 7 | import torch.distributed 8 | import torch.distributed.rpc as rpc 9 | from tqdm import tqdm 10 | 11 | import gnnflow.distributed.graph_services as graph_services 12 | from gnnflow.distributed.dispatcher import get_dispatcher 13 | from gnnflow.distributed.kvstore import KVStoreServer 14 | from gnnflow.utils import load_dataset_in_chunks, load_feat 15 | 16 | 17 | def initialize(rank: int, world_size: int, partition_strategy: str, 18 | num_partitions: int, data_name: str, dim_memory: int): 19 | """ 20 | Initialize the distributed environment. 21 | 22 | Args: 23 | rank (int): The rank of the current process. 24 | world_size (int): The number of processes participating in the job. 25 | num_partitions (int): The number of partitions to split the dataset into. 26 | data_name (str): the dataset name of the dataset for loading features. 27 | dim_memory (int): the dimension of memory 28 | """ 29 | # NB: disable IB according to https://github.com/pytorch/pytorch/issues/86962 30 | rpc.init_rpc("worker%d" % rank, rank=rank, world_size=world_size, 31 | rpc_backend_options=rpc.TensorPipeRpcBackendOptions( 32 | rpc_timeout=180000, 33 | num_worker_threads=32, 34 | _transports=["shm", "uv"], 35 | _channels=["cma", "mpt_uv", "basic", "cuda_xth", "cuda_ipc", "cuda_basic"])) 36 | logging.info("Rank %d: Initialized RPC.", rank) 37 | 38 | local_rank = int(os.environ["LOCAL_RANK"]) 39 | 40 | # Initialize the KVStore. 41 | if local_rank == 0: 42 | node_feats, edge_feats = load_feat(data_name, memmap=True) 43 | dim_node = 0 if node_feats is None else node_feats.shape[1] 44 | dim_edge = 0 if edge_feats is None else edge_feats.shape[1] 45 | graph_services.set_kvstore_server(KVStoreServer( 46 | node_feats, edge_feats, dim_memory, dim_edge)) 47 | 48 | if rank == 0: 49 | dispatcher = get_dispatcher( 50 | partition_strategy, num_partitions, dim_node > 0, dim_edge > 0, dim_memory > 0, data_name) 51 | dispatcher.broadcast_node_edge_dim(dim_node, dim_edge) 52 | 53 | torch.distributed.barrier() 54 | if rank == 0: 55 | logging.info("initialized done") 56 | 57 | 58 | def dispatch_full_dataset(rank: int, data_name: str, 59 | initial_ingestion_batch_size: int, ingestion_batch_size: int): 60 | start = time.time() 61 | if rank == 0: 62 | dispatcher = get_dispatcher() 63 | # read csv in chunks 64 | df_iterator = load_dataset_in_chunks( 65 | data_name, chunksize=initial_ingestion_batch_size) 66 | 67 | t = tqdm() 68 | # ingest the first chunk 69 | for i, dataset in enumerate(df_iterator): 70 | dataset.rename(columns={'Unnamed: 0': 'eid'}, inplace=True) 71 | if i > 0: 72 | for i in range(0, initial_ingestion_batch_size, ingestion_batch_size): 73 | dataset_chunk = dataset.iloc[i:i + ingestion_batch_size] 74 | dispatcher.partition_graph(dataset_chunk, False) 75 | t.update(len(dataset_chunk)) 76 | else: 77 | dispatcher.partition_graph(dataset, False) 78 | t.update(initial_ingestion_batch_size) 79 | del dataset 80 | 81 | t.close() 82 | logging.info("Rank 0: Ingestion edges done in %.2fs.", 83 | time.time() - start) 84 | dispatcher.dispatch_node_memory() 85 | logging.info("Rank 0: Dispatch node memory done in %.2fs.", 86 | time.time() - start) 87 | dispatcher.broadcast_rand_sampler() 88 | logging.info("Rank 0: Broadcast rand sampler done in %.2fs.", 89 | time.time() - start) 90 | 91 | # check 92 | torch.distributed.barrier() 93 | if rank == 0: 94 | logging.info("Rank %d: Ingested full dataset in %f seconds.", rank, 95 | time.time() - start) 96 | 97 | logging.info("Rank %d: Number of vertices: %d, number of 
edges: %d", 98 | rank, graph_services.num_vertices(), graph_services.num_edges()) 99 | logging.info("Rank %d: partition table shape: %s", 100 | rank, str(graph_services.get_partition_table().shape)) 101 | 102 | # save partition table 103 | dgraph = graph_services.get_dgraph() 104 | logging.info("Rank %d: local number of vertices: %d, number of edges: %d", 105 | rank, dgraph._dgraph.num_vertices(), dgraph._dgraph.num_edges()) 106 | mem = psutil.virtual_memory().percent 107 | logging.info("build graph done memory usage: {}".format(mem)) 108 | -------------------------------------------------------------------------------- /gnnflow/distributed/dist_graph.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | import time 4 | from queue import Queue 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from gnnflow import DynamicGraph 10 | from gnnflow.distributed.utils import HandleManager 11 | 12 | 13 | class DistributedDynamicGraph: 14 | """ 15 | Distributed dynamic graph. 16 | """ 17 | 18 | def __init__(self, dgraph: DynamicGraph): 19 | """ 20 | Initialize the distributed dynamic graph. 21 | 22 | Args: 23 | dgraph (DynamicGraph): The dynamic graph instance. 24 | """ 25 | self._dgraph = dgraph 26 | self._num_partitions = None 27 | self._partition_table = None 28 | self._num_vertices = 0 29 | self._num_edges = 0 30 | self._max_vertex_id = 0 31 | 32 | self._handle_manager = HandleManager() 33 | self._handles = set() 34 | self._add_edges_thread = threading.Thread(target=self._add_edges_loop) 35 | self._add_edges_queue = Queue() 36 | self._add_edges_thread.start() 37 | 38 | def shutdown(self): 39 | logging.info('DistributedDynamicGraph shutdown') 40 | self._add_edges_queue.put((None, None, None, None, None)) 41 | self._add_edges_thread.join() 42 | 43 | def _add_edges_loop(self): 44 | while True: 45 | while not self._add_edges_queue.empty(): 46 | source_vertices, target_vertices, timestamps, eids, handle = self._add_edges_queue.get() 47 | if handle is None: 48 | return 49 | 50 | self._dgraph.add_edges( 51 | source_vertices, target_vertices, timestamps, eids) 52 | 53 | self._handle_manager.mark_done(handle) 54 | 55 | time.sleep(0.001) 56 | 57 | def enqueue_add_edges_task(self, source_vertices: np.ndarray, target_vertices: np.ndarray, 58 | timestamps: np.ndarray, eids: np.ndarray): 59 | handle = self._handle_manager.allocate_handle() 60 | self._add_edges_queue.put( 61 | (source_vertices, target_vertices, timestamps, eids, handle)) 62 | self._handles.add(handle) 63 | 64 | def poll(self, handle): 65 | return self._handle_manager.poll(handle) 66 | 67 | def wait_for_all_updates_to_finish(self): 68 | for handle in self._handles.copy(): 69 | if self.poll(handle): 70 | self._handles.remove(handle) 71 | 72 | def num_vertices(self) -> int: 73 | """ 74 | Get the number of vertices in the dynamic graph. 75 | Returns: 76 | int: The number of vertices. 77 | """ 78 | return self._num_vertices 79 | 80 | def num_edges(self) -> int: 81 | """ 82 | Get the number of edges in the dynamic graph. 83 | 84 | Returns: 85 | int: The number of edges. 86 | """ 87 | return self._num_edges 88 | 89 | def num_source_vertices(self) -> int: 90 | return self._dgraph.num_source_vertices() 91 | 92 | def nodes(self) -> np.ndarray: 93 | """ 94 | Return the nodes of the graph. 95 | """ 96 | return self._dgraph.nodes() 97 | 98 | def src_nodes(self) -> np.ndarray: 99 | """ 100 | Return the source nodes of the graph. 
101 | """ 102 | return self._dgraph.src_nodes() 103 | 104 | def edges(self) -> np.ndarray: 105 | """ 106 | Return the edges of the graph. 107 | """ 108 | return self._dgraph.edges() 109 | 110 | def set_num_vertices(self, num_vertices: int): 111 | """ 112 | Set the number of vertices in the dynamic graph. 113 | 114 | Args: 115 | num_vertices (int): The number of vertices. 116 | """ 117 | self._num_vertices = num_vertices 118 | 119 | def set_num_edges(self, num_edges: int): 120 | """ 121 | Set the number of edges in the dynamic graph. 122 | 123 | Args: 124 | num_edges (int): The number of edges. 125 | """ 126 | self._num_edges = num_edges 127 | 128 | def max_vertex_id(self) -> int: 129 | return self._max_vertex_id 130 | 131 | def set_max_vertex_id(self, max_vertex_id: int): 132 | self._max_vertex_id = max_vertex_id 133 | 134 | def set_partition_table(self, partition_table: torch.Tensor): 135 | """ 136 | Set the partition table. 137 | 138 | Args: 139 | partition_table (torch.Tensor): The partition table. 140 | """ 141 | self._partition_table = partition_table 142 | 143 | def get_partition_table(self) -> torch.Tensor: 144 | """ 145 | Get the partition table. 146 | 147 | Returns: 148 | torch.Tensor: The partition table. 149 | """ 150 | if self._partition_table is None: 151 | raise RuntimeError('Partition table is not set.') 152 | return self._partition_table 153 | 154 | def set_num_partitions(self, num_partitions: int): 155 | """ 156 | Set the number of partitions. 157 | 158 | Args: 159 | num_partitions (int): The number of partitions. 160 | """ 161 | self._num_partitions = num_partitions 162 | 163 | def num_partitions(self) -> int: 164 | """ 165 | Get the number of partitions. 166 | 167 | Returns: 168 | int: The number of partitions. 169 | """ 170 | if self._num_partitions is None: 171 | raise RuntimeError('Number of partitions is not set.') 172 | return self._num_partitions 173 | 174 | def add_edges(self, source_vertices: np.ndarray, target_vertices: np.ndarray, 175 | timestamps: np.ndarray, eids: np.ndarray): 176 | return self._dgraph.add_edges(source_vertices, target_vertices, timestamps, eids) 177 | 178 | def out_degree(self, vertices: np.ndarray): 179 | return self._dgraph.out_degree(vertices) 180 | -------------------------------------------------------------------------------- /gnnflow/distributed/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from enum import Enum 3 | from threading import Lock 4 | 5 | 6 | class WorkStatus(Enum): 7 | """Work status.""" 8 | DOING = 0 9 | DONE = 1 10 | 11 | 12 | class HandleManager: 13 | """A thread-safe manager for handles. 14 | 15 | This class is used to manage handles for the distributed training. 16 | """ 17 | 18 | def __init__(self): 19 | # int -> WorkStatus 20 | self._last_handle = 0 21 | self._handles = defaultdict(lambda: WorkStatus.DOING) 22 | self._lock = Lock() 23 | 24 | def allocate_handle(self): 25 | """Allocate a handle. 26 | 27 | Returns: 28 | int: The handle. 29 | """ 30 | with self._lock: 31 | self._last_handle += 1 32 | handle = self._last_handle 33 | self._handles[handle] = WorkStatus.DOING 34 | return handle 35 | 36 | def mark_done(self, handle): 37 | """Mark a handle as done. 38 | 39 | Args: 40 | handle (int): The handle. 41 | """ 42 | with self._lock: 43 | self._handles[handle] = WorkStatus.DONE 44 | 45 | def poll(self, handle): 46 | """Poll a handle. 47 | 48 | Args: 49 | handle (int): The handle. 
50 | 51 | Returns: 52 | bool: True if the handle is done, False otherwise. 53 | """ 54 | with self._lock: 55 | return self._handles[handle] == WorkStatus.DONE 56 | -------------------------------------------------------------------------------- /gnnflow/dynamic_graph.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import numpy as np 4 | 5 | from libgnnflow import InsertionPolicy, MemoryResourceType, _DynamicGraph 6 | 7 | 8 | class DynamicGraph: 9 | """ 10 | A dynamic graph is a graph that can be updated at runtime. 11 | 12 | The dynamic graph is implemented as block adjacency list. It has a vertex 13 | table where each entry is a linked list of blocks. Each block contains 14 | a list of edges. Each edge is a tuple of (target_vertex, timestamp). 15 | """ 16 | 17 | def __init__( 18 | self, initial_pool_size: int, 19 | maximum_pool_size: int, 20 | mem_resource_type: str, 21 | minimum_block_size: int, 22 | blocks_to_preallocate: int, 23 | insertion_policy: str, 24 | source_vertices: Optional[np.ndarray] = None, 25 | target_vertices: Optional[np.ndarray] = None, 26 | timestamps: Optional[np.ndarray] = None, 27 | eids: Optional[np.ndarray] = None, 28 | add_reverse: bool = False, 29 | device: int = 0, 30 | adaptive_block_size: bool = True): 31 | """ 32 | The graph is initially empty and can be optionaly initialized with 33 | a list of edges. 34 | 35 | Args: 36 | initial_pool_size: optional, int, the initial pool size of the graph. 37 | maximum_pool_size: optional, int, the maximum pool size of the graph. 38 | mem_resource_type: optional, str, the memory resource type. 39 | valid options: ("cuda", "unified", "pinned", or "shared") (case insensitive). 40 | minimum_block_size: optional, int, the minimum block size of the graph. 41 | blocks_to_preallocate: optional, int, the number of blocks to preallocate. 42 | insertion_policy: the insertion policy to use 43 | valid options: ("insert" or "replace") (case insensitive). 44 | source_vertices: optional, 1D tensor, the source vertices of the edges. 45 | target_vertices: optional, 1D tensor, the target vertices of the edges. 46 | timestamps: optional, 1D tensor, the timestamps of the edges. 47 | eids: optional, 1D tensor, the edge ids of the edges. 48 | add_reverse: optional, bool, whether to add reverse edges. 49 | device: optional, int, the device to use. 50 | adaptive_block_size: optional, bool, whether to use adaptive block size. 
51 | """ 52 | mem_resource_type = mem_resource_type.lower() 53 | if mem_resource_type == "cuda": 54 | mem_resource_type = MemoryResourceType.CUDA 55 | elif mem_resource_type == "unified": 56 | mem_resource_type = MemoryResourceType.UNIFIED 57 | elif mem_resource_type == "pinned": 58 | mem_resource_type = MemoryResourceType.PINNED 59 | elif mem_resource_type == "shared": 60 | mem_resource_type = MemoryResourceType.SHARED 61 | else: 62 | raise ValueError("Invalid memory resource type: {}".format( 63 | mem_resource_type)) 64 | 65 | insertion_policy = insertion_policy.lower() 66 | if insertion_policy == "insert": 67 | insertion_policy = InsertionPolicy.INSERT 68 | elif insertion_policy == "replace": 69 | insertion_policy = InsertionPolicy.REPLACE 70 | else: 71 | raise ValueError("Invalid insertion policy: {}".format( 72 | insertion_policy)) 73 | 74 | self._dgraph = _DynamicGraph( 75 | initial_pool_size, maximum_pool_size, mem_resource_type, 76 | minimum_block_size, blocks_to_preallocate, insertion_policy, 77 | device, adaptive_block_size) 78 | 79 | # initialize the graph with edges 80 | if source_vertices is not None and target_vertices is not None \ 81 | and timestamps is not None: 82 | self.add_edges(source_vertices, target_vertices, 83 | timestamps, eids, add_reverse) 84 | 85 | def add_edges( 86 | self, source_vertices: np.ndarray, target_vertices: np.ndarray, 87 | timestamps: np.ndarray, eids: Optional[np.ndarray] = None, add_reverse: bool = False): 88 | """ 89 | Add edges to the graph. Note that we do not assume that the incoming 90 | edges are sorted by timestamps. The function will sort the incoming 91 | edges by timestamps. 92 | 93 | Args: 94 | source_vertices: 1D tensor, the source vertices of the edges. 95 | target_vertices: 1D tensor, the target vertices of the edges. 96 | timestamps: 1D tensor, the timestamps of the edges. 97 | eids: 1D tensor, the edge ids of the edges. 98 | add_reverse: optional, bool, whether to add reverse edges. 99 | 100 | Raises: 101 | ValueError: if the timestamps are older than the existing edges in 102 | the graph. 103 | """ 104 | assert len(source_vertices.shape) == 1 and len( 105 | target_vertices.shape) == 1 and len(timestamps.shape) == 1, "Edges must be 1D tensors" 106 | 107 | assert source_vertices.shape[0] == target_vertices.shape[0] == \ 108 | timestamps.shape[0], "The number of source vertices, target vertices, timestamps, " \ 109 | "and edge ids must be the same." 110 | 111 | if eids is None: 112 | num_edges = self.num_edges() 113 | eids = np.arange(num_edges, num_edges + len(source_vertices)) 114 | 115 | if add_reverse: 116 | source_vertices_ext = np.concatenate( 117 | [source_vertices, target_vertices]) 118 | target_vertices_ext = np.concatenate( 119 | [target_vertices, source_vertices]) 120 | source_vertices = source_vertices_ext 121 | target_vertices = target_vertices_ext 122 | timestamps = np.concatenate([timestamps, timestamps]) 123 | eids = np.concatenate([eids, eids]) 124 | 125 | self._dgraph.add_edges( 126 | source_vertices, target_vertices, timestamps, eids) 127 | 128 | def offload_old_blocks(self, timestamp: float, to_file: bool = False): 129 | """ 130 | Offload the old blocks from the graph. 131 | 132 | Args: 133 | timestamp: the current timestamp. 134 | to_file: whether to offload the blocks to file. 135 | 136 | Return: 137 | the number of blocks offloaded. 
138 | """ 139 | return self._dgraph.offload_old_blocks(timestamp, to_file) 140 | 141 | def num_vertices(self) -> int: 142 | return self._dgraph.num_vertices() 143 | 144 | def num_source_vertices(self) -> int: 145 | return self._dgraph.num_source_vertices() 146 | 147 | def max_vertex_id(self) -> int: 148 | return self._dgraph.max_vertex_id() 149 | 150 | def num_edges(self) -> int: 151 | return self._dgraph.num_edges() 152 | 153 | def out_degree(self, vertexs: np.ndarray) -> np.ndarray: 154 | return self._dgraph.out_degree(vertexs) 155 | 156 | def nodes(self) -> np.ndarray: 157 | """ 158 | Return the nodes of the graph. 159 | """ 160 | return self._dgraph.nodes() 161 | 162 | def src_nodes(self) -> np.ndarray: 163 | """ 164 | Return the source nodes of the graph. 165 | """ 166 | return self._dgraph.src_nodes() 167 | 168 | def edges(self) -> np.ndarray: 169 | """ 170 | Return the edges of the graph. 171 | """ 172 | return self._dgraph.edges() 173 | 174 | def get_temporal_neighbors(self, vertex: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 175 | """ 176 | Return the neighbors of the specified vertex. The neighbors are sorted 177 | by timestamps in decending order. 178 | 179 | Note that this function is inefficient and should be used sparingly. 180 | 181 | Args: 182 | vertex: the vertex to get neighbors for. 183 | 184 | Returns: A tuple of (target_vertices, timestamps, edge_ids) 185 | """ 186 | return self._dgraph.get_temporal_neighbors(vertex) 187 | 188 | def avg_linked_list_length(self) -> float: 189 | """ 190 | Return the average linked list length. 191 | """ 192 | return self._dgraph.avg_linked_list_length() 193 | 194 | def get_graph_memory_usage(self) -> int: 195 | """ 196 | Return the graph memory usage of the graph in bytes. 197 | """ 198 | return self._dgraph.get_graph_memory_usage() 199 | 200 | def get_metadata_memory_usage(self) -> int: 201 | """ 202 | Return the metadata memory usage of the graph in bytes. 
203 | """ 204 | return self._dgraph.get_metadata_memory_usage() 205 | -------------------------------------------------------------------------------- /gnnflow/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeterSH6/MSPipe/3e4f3efb998da621e756d07296025c63c385746a/gnnflow/models/__init__.py -------------------------------------------------------------------------------- /gnnflow/models/apan.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union 2 | import torch 3 | from gnnflow.distributed.kvstore import KVStoreClient 4 | 5 | from gnnflow.models.modules.layers import TransfomerAttentionLayer, IdentityNormLayer, EdgePredictor 6 | from gnnflow.models.modules.apan_memory import Memory 7 | from gnnflow.models.modules.memory_updater import TransformerMemoryUpdater 8 | 9 | 10 | class APAN(torch.nn.Module): 11 | 12 | def __init__(self, dim_node, dim_edge, dim_time=100, 13 | dim_embed=100, num_layers=1, num_snapshots=1, 14 | att_head=2, dropout=0.1, att_dropout=0.1, 15 | use_memory=True, dim_memory=100, 16 | num_nodes: Optional[int] = None, 17 | memory_device: Union[torch.device, str] = 'cpu', 18 | memory_shared: bool = False, 19 | kvstore_client: Optional[KVStoreClient] = None, 20 | mailbox_size=10, 21 | *args, **kwargs 22 | ): 23 | super(APAN, self).__init__() 24 | self.dim_node = dim_node 25 | self.dim_node_input = dim_node 26 | self.dim_edge = dim_edge 27 | 28 | self.dim_time = dim_time 29 | self.dim_embed = dim_embed 30 | self.num_layers = num_layers 31 | self.num_snapshots = num_snapshots 32 | self.att_head = att_head 33 | self.dropout = dropout 34 | self.att_dropout = att_dropout 35 | self.use_memory = use_memory 36 | 37 | self.gnn_layer = num_layers 38 | self.dropout = dropout 39 | self.attn_dropout = att_dropout 40 | 41 | self.mailbox_size = mailbox_size 42 | 43 | # TODO:.... 
44 | self.memory = Memory(num_nodes, dim_edge, dim_memory, 45 | memory_device, memory_shared, mailbox_size, kvstore_client) 46 | 47 | # Memory updater 48 | self.memory_updater = TransformerMemoryUpdater( 49 | mailbox_size, att_head, 50 | 2 * dim_memory + dim_edge, 51 | dim_memory, 52 | dim_time, 53 | dropout, att_dropout) 54 | 55 | self.dim_node_input = dim_memory 56 | 57 | self.layers = torch.nn.ModuleDict() 58 | 59 | self.gnn_layer = 1 60 | for h in range(num_snapshots): 61 | self.layers['l0h' + 62 | str(h)] = IdentityNormLayer(self.dim_node_input) 63 | 64 | self.last_updated = None 65 | self.edge_predictor = EdgePredictor(dim_memory) 66 | 67 | def forward(self, mfgs, neg_samples=1): 68 | out = list() 69 | for l in range(self.gnn_layer): 70 | for h in range(self.num_snapshots): 71 | rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h]) 72 | if l != self.gnn_layer - 1: 73 | mfgs[l + 1][h].srcdata['h'] = rst 74 | else: 75 | out.append(rst) 76 | 77 | if self.num_snapshots == 1: 78 | out = out[0] 79 | else: 80 | out = torch.stack(out, dim=0) 81 | out = self.combiner(out)[0][-1, :, :] 82 | return self.edge_predictor(out) 83 | 84 | def get_emb(self, mfgs): 85 | self.memory_updater(mfgs[0]) 86 | out = list() 87 | for l in range(self.gnn_layer): 88 | for h in range(self.num_snapshots): 89 | rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h]) 90 | if l != self.gnn_layer - 1: 91 | mfgs[l + 1][h].srcdata['h'] = rst 92 | else: 93 | out.append(rst) 94 | if self.num_snapshots == 1: 95 | out = out[0] 96 | else: 97 | out = torch.stack(out, dim=0) 98 | out = self.combiner(out)[0][-1, :, :] 99 | return out 100 | 101 | def reset(self): 102 | if self.use_memory: 103 | self.memory.reset() 104 | 105 | def resize(self, num_nodes: int): 106 | if self.use_memory: 107 | self.memory.resize(num_nodes) 108 | 109 | def has_memory(self): 110 | return self.use_memory 111 | 112 | def backup_memory(self) -> Dict: 113 | if self.use_memory: 114 | return self.memory.backup() 115 | return {} 116 | 117 | def restore_memory(self, backup: Dict): 118 | if self.use_memory: 119 | self.memory.restore(backup) 120 | -------------------------------------------------------------------------------- /gnnflow/models/dgnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the implementation of TGL's model module. 
3 | 4 | Implementation at: 5 | https://github.com/amazon-research/tgl/blob/main/modules.py 6 | """ 7 | from typing import Dict, List, Optional, Union 8 | 9 | import torch 10 | from dgl.heterograph import DGLBlock 11 | from gnnflow.distributed.kvstore import KVStoreClient 12 | from gnnflow.models.modules.layers import EdgePredictor, TransfomerAttentionLayer 13 | from gnnflow.models.modules.memory import Memory 14 | from gnnflow.models.modules.memory_updater import GRUMemeoryUpdater 15 | 16 | 17 | class DGNN(torch.nn.Module): 18 | """ 19 | Dynamic Graph Neural Model (DGNN) 20 | """ 21 | 22 | def __init__(self, dim_node: int, dim_edge: int, dim_time: int, 23 | dim_embed: int, num_layers: int, num_snapshots: int, 24 | att_head: int, dropout: float, att_dropout: float, 25 | use_memory: bool, dim_memory: Optional[int] = None, 26 | num_nodes: Optional[int] = None, 27 | memory_device: Union[torch.device, str] = 'cpu', 28 | memory_shared: bool = False, 29 | kvstore_client: Optional[KVStoreClient] = None, 30 | *args, **kwargs): 31 | """ 32 | Args: 33 | dim_node: dimension of node features/embeddings 34 | dim_edge: dimension of edge features 35 | dim_time: dimension of time features 36 | dim_embed: dimension of output embeddings 37 | num_layers: number of layers 38 | num_snapshots: number of snapshots 39 | att_head: number of heads for attention 40 | dropout: dropout rate 41 | att_dropout: dropout rate for attention 42 | use_memory: whether to use memory 43 | dim_memory: dimension of memory 44 | num_nodes: number of nodes in the graph 45 | memory_device: device of the memory 46 | memory_shared: whether to share memory across local workers 47 | kvstore_client: The KVStore_Client for fetching memorys when using partition 48 | """ 49 | super(DGNN, self).__init__() 50 | self.dim_node = dim_node 51 | self.dim_node_input = dim_node 52 | self.dim_edge = dim_edge 53 | self.dim_time = dim_time 54 | self.dim_embed = dim_embed 55 | self.num_layers = num_layers 56 | self.num_snapshots = num_snapshots 57 | self.att_head = att_head 58 | self.dropout = dropout 59 | self.att_dropout = att_dropout 60 | self.use_memory = use_memory 61 | 62 | if self.use_memory: 63 | assert num_snapshots == 1, 'memory is not supported for multiple snapshots' 64 | assert dim_memory is not None, 'dim_memory should be specified' 65 | assert num_nodes is not None, 'num_nodes is required when using memory' 66 | 67 | self.memory = Memory(num_nodes, dim_edge, dim_memory, 68 | memory_device, memory_shared, 69 | kvstore_client) 70 | 71 | self.memory_updater = GRUMemeoryUpdater( 72 | dim_node, dim_edge, dim_time, dim_embed, dim_memory) 73 | dim_node = dim_memory 74 | 75 | self.layers = torch.nn.ModuleDict() 76 | for l in range(num_layers): 77 | for h in range(num_snapshots): 78 | if l == 0: 79 | dim_node_input = dim_node 80 | else: 81 | dim_node_input = dim_embed 82 | 83 | key = 'l' + str(l) + 'h' + str(h) 84 | self.layers[key] = TransfomerAttentionLayer(dim_node_input, 85 | dim_edge, 86 | dim_time, 87 | dim_embed, 88 | att_head, 89 | dropout, 90 | att_dropout) 91 | 92 | if self.num_snapshots > 1: 93 | self.combiner = torch.nn.RNN( 94 | dim_embed, dim_embed) 95 | 96 | self.last_updated = None 97 | self.edge_predictor = EdgePredictor(dim_embed) 98 | 99 | def reset(self): 100 | if self.use_memory: 101 | self.memory.reset() 102 | 103 | def resize(self, num_nodes: int): 104 | if self.use_memory: 105 | self.memory.resize(num_nodes) 106 | 107 | def has_memory(self): 108 | return self.use_memory 109 | 110 | def backup_memory(self) -> Dict: 111 | if 
self.use_memory: 112 | return self.memory.backup() 113 | return {} 114 | 115 | def restore_memory(self, backup: Dict): 116 | if self.use_memory: 117 | self.memory.restore(backup) 118 | 119 | def forward(self, mfgs: List[List[DGLBlock]], return_embed: bool =False, neg_samples: int =1): 120 | """ 121 | Args: 122 | mfgs: list of list of DGLBlocks 123 | neg_sample_ratio: negative sampling ratio 124 | """ 125 | out = list() 126 | for l in range(self.num_layers): 127 | for h in range(self.num_snapshots): 128 | key = 'l' + str(l) + 'h' + str(h) 129 | rst = self.layers[key](mfgs[l][h]) 130 | if l != self.num_layers - 1: 131 | mfgs[l + 1][h].srcdata['h'] = rst 132 | else: 133 | out.append(rst) 134 | 135 | if self.num_snapshots == 1: 136 | embed = out[0] 137 | else: 138 | embed = torch.stack(out, dim=0) 139 | embed = self.combiner(embed)[0][-1, :, :] 140 | 141 | if return_embed: 142 | return embed 143 | return self.edge_predictor(embed, neg_samples) 144 | -------------------------------------------------------------------------------- /gnnflow/models/gat.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import dgl.nn as dglnn 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from dgl.heterograph import DGLBlock 7 | 8 | 9 | class GAT(nn.Module): 10 | def __init__(self, dim_in: int, dim_out: int, 11 | num_layers: int = 2, 12 | attn_head: List[int] = [8, 1], 13 | feat_drop: float = 0, 14 | attn_drop: float = 0, 15 | allow_zero_in_degree: bool = True): 16 | if num_layers != len(attn_head): 17 | raise ValueError( 18 | "length of attn head {} must equal to num_layers {}".format( 19 | attn_head, num_layers)) 20 | super().__init__() 21 | self.num_layers = num_layers 22 | self.layers = nn.ModuleDict() 23 | # TODO: gat should deal with zero in-degree problem 24 | for l in range(num_layers): 25 | # static graph doesn't have snapshot 26 | key = 'l' + str(l) + 'h' + str(0) 27 | if l == 0: 28 | self.layers[key] = dglnn.GATConv( 29 | dim_in, 30 | dim_out, 31 | attn_head[0], 32 | feat_drop=feat_drop, 33 | attn_drop=attn_drop, 34 | activation=F.elu, 35 | allow_zero_in_degree=allow_zero_in_degree 36 | ) 37 | else: 38 | self.layers[key] = dglnn.GATConv( 39 | dim_out * attn_head[l-1], 40 | dim_out, 41 | attn_head[l], 42 | feat_drop=feat_drop, 43 | attn_drop=attn_drop, 44 | activation=None, 45 | allow_zero_in_degree=allow_zero_in_degree 46 | ) 47 | 48 | self.dim_out = dim_out 49 | # use the same predictor as graphSage 50 | self.predictor = nn.Sequential( 51 | nn.Linear(dim_out, dim_out), 52 | nn.ReLU(), 53 | nn.Linear(dim_out, dim_out), 54 | nn.ReLU(), 55 | nn.Linear(dim_out, 1)) 56 | 57 | def reset(self): 58 | pass 59 | 60 | def forward(self, mfgs: List[List[DGLBlock]], neg_sample_ratio: int = 1, *args, **kwargs): 61 | for l in range(self.num_layers): 62 | key = 'l' + str(l) + 'h' + str(0) 63 | h = self.layers[key](mfgs[l][0], mfgs[l][0].srcdata['h']) 64 | if l != self.num_layers - 1: # not last layer 65 | h = h.flatten(1) 66 | mfgs[l + 1][0].srcdata['h'] = h 67 | else: 68 | # last layer use mean 69 | h = h.mean(1) 70 | 71 | num_edge = h.shape[0] // (neg_sample_ratio + 2) 72 | src_h = h[:num_edge] 73 | pos_dst_h = h[num_edge:2 * num_edge] 74 | neg_dst_h = h[2 * num_edge:] 75 | h_pos = self.predictor(src_h * pos_dst_h) 76 | # TODO: it seems that neg sample of static graph is different from dynamic 77 | h_neg = self.predictor(src_h.tile(neg_sample_ratio, 1) * neg_dst_h) 78 | return h_pos, h_neg 79 | 
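# Illustrative sketch (not part of the repository): how the stacked embedding
# tensor `h` from the final layer is split for link prediction.  The batch is
# assumed to be laid out as [source | positive dst | negative dst] rows, with
# `neg_sample_ratio` negatives per positive edge; this is the layout that the
# slicing in GAT.forward above (and SAGE.forward below) relies on.
if __name__ == "__main__":
    import torch

    neg_sample_ratio = 1
    num_edge, dim_out = 4, 16                 # hypothetical batch sizes
    h = torch.randn(num_edge * (neg_sample_ratio + 2), dim_out)

    src_h = h[:num_edge]                      # source-node embeddings
    pos_dst_h = h[num_edge:2 * num_edge]      # positive-destination embeddings
    neg_dst_h = h[2 * num_edge:]              # negative-destination embeddings
    assert neg_dst_h.shape[0] == num_edge * neg_sample_ratio

    # The predictor MLP in GAT/SAGE consumes these element-wise products.
    pos_in = src_h * pos_dst_h
    neg_in = src_h.tile(neg_sample_ratio, 1) * neg_dst_h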
-------------------------------------------------------------------------------- /gnnflow/models/graphsage.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import dgl 4 | import dgl.nn as dglnn 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from dgl.heterograph import DGLBlock 9 | 10 | 11 | class SAGE(nn.Module): 12 | def __init__(self, dim_node: int, dim_out: int, 13 | num_layers: int = 2, 14 | aggregator: Optional[str] = 'mean'): 15 | 16 | if aggregator not in ['mean', 'gcn', 'pool', 'lstm']: 17 | raise ValueError( 18 | "aggregator {} is not in ['mean', 'gcn', 'pool', 'lstm']".format(aggregator)) 19 | super().__init__() 20 | self.num_layers = num_layers 21 | 22 | self.layers = nn.ModuleDict() 23 | for l in range(num_layers): 24 | # static graph doesn't have snapshot 25 | key = 'l' + str(l) + 'h' + str(0) 26 | if l == 0: 27 | self.layers[key] = dglnn.SAGEConv( 28 | dim_node, dim_out, aggregator) 29 | else: 30 | self.layers[key] = dglnn.SAGEConv( 31 | dim_out, dim_out, aggregator) 32 | 33 | self.dim_out = dim_out 34 | self.predictor = nn.Sequential( 35 | nn.Linear(dim_out, dim_out), 36 | nn.ReLU(), 37 | nn.Linear(dim_out, dim_out), 38 | nn.ReLU(), 39 | nn.Linear(dim_out, 1)) 40 | 41 | def reset(self): 42 | pass 43 | 44 | def forward(self, mfgs: List[List[DGLBlock]], neg_sample_ratio: int = 1, *args, **kwargs): 45 | """ 46 | Args: 47 | b: sampled message flow graph (mfg), where 48 | `b.num_dst_nodes()` is the number of target nodes to sample, 49 | `b.srcdata['h']` is the embedding of all nodes, 50 | `b.edge['f']` is the edge features of sampled edges, and 51 | `b.edata['dt']` is the delta time of sampled edges. 52 | 53 | Returns: 54 | output: output embedding of target nodes (shape: (num_dst_nodes, dim_out)) 55 | """ 56 | for l in range(self.num_layers): 57 | key = 'l' + str(l) + 'h' + str(0) 58 | h = self.layers[key](mfgs[l][0], mfgs[l][0].srcdata['h']) 59 | if l != self.num_layers - 1: 60 | h = F.relu(h) 61 | mfgs[l + 1][0].srcdata['h'] = h 62 | 63 | num_edge = h.shape[0] // (neg_sample_ratio + 2) 64 | src_h = h[:num_edge] 65 | pos_dst_h = h[num_edge:2 * num_edge] 66 | neg_dst_h = h[2 * num_edge:] 67 | h_pos = self.predictor(src_h * pos_dst_h) 68 | # TODO: it seems that neg sample of static graph is different from dynamic 69 | h_neg = self.predictor(src_h.tile(neg_sample_ratio, 1) * neg_dst_h) 70 | return h_pos, h_neg 71 | -------------------------------------------------------------------------------- /gnnflow/models/jodie.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union 2 | import torch 3 | 4 | from gnnflow.distributed.kvstore import KVStoreClient 5 | from gnnflow.models.modules.layers import EdgePredictor, IdentityNormLayer, JODIETimeEmbedding 6 | from gnnflow.models.modules.memory import Memory 7 | from gnnflow.models.modules.memory_updater import RNNMemeoryUpdater 8 | 9 | 10 | class JODIE(torch.nn.Module): 11 | 12 | def __init__(self, dim_node, dim_edge, dim_time=100, 13 | dim_embed=100, num_layers=1, num_snapshots=1, 14 | att_head=2, dropout=0.1, att_dropout=0.1, 15 | use_memory=True, dim_memory=100, 16 | num_nodes: Optional[int] = None, 17 | memory_device: Union[torch.device, str] = 'cpu', 18 | memory_shared: bool = False, 19 | kvstore_client: Optional[KVStoreClient] = None, 20 | *args, **kwargs): 21 | super(JODIE, self).__init__() 22 | self.dim_node = dim_node 23 | self.dim_node_input 
= dim_node 24 | self.dim_edge = dim_edge 25 | 26 | self.dim_time = dim_time 27 | self.dim_embed = dim_embed 28 | self.num_layers = num_layers 29 | self.num_snapshots = num_snapshots 30 | self.att_head = att_head 31 | self.dropout = dropout 32 | self.att_dropout = att_dropout 33 | self.use_memory = use_memory 34 | 35 | self.gnn_layer = num_layers 36 | self.dropout = dropout 37 | self.attn_dropout = att_dropout 38 | 39 | # Use Memory 40 | self.memory = Memory(num_nodes, dim_edge, dim_memory, 41 | memory_device, memory_shared, kvstore_client) 42 | 43 | # Memory updater 44 | self.memory_updater = RNNMemeoryUpdater( 45 | dim_node, dim_edge, dim_time, dim_embed, dim_memory) 46 | 47 | self.dim_node_input = dim_embed 48 | 49 | self.layers = torch.nn.ModuleDict() 50 | 51 | self.gnn_layer = 1 52 | for h in range(num_snapshots): 53 | self.layers['l0h' + 54 | str(h)] = IdentityNormLayer(self.dim_node_input) 55 | self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(dim_embed) 56 | 57 | self.edge_predictor = EdgePredictor(dim_embed) 58 | 59 | def forward(self, mfgs, neg_samples=1): 60 | out = list() 61 | for l in range(self.gnn_layer): 62 | for h in range(self.num_snapshots): 63 | rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h]) 64 | rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l] 65 | [h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts']) 66 | 67 | if l != self.gnn_layer - 1: 68 | mfgs[l + 1][h].srcdata['h'] = rst 69 | else: 70 | out.append(rst) 71 | 72 | if self.num_snapshots == 1: 73 | out = out[0] 74 | else: 75 | out = torch.stack(out, dim=0) 76 | out = self.combiner(out)[0][-1, :, :] 77 | return self.edge_predictor(out) 78 | 79 | def get_emb(self, mfgs): 80 | self.memory_updater(mfgs[0]) 81 | out = list() 82 | for l in range(self.gnn_layer): 83 | for h in range(self.num_snapshots): 84 | rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h]) 85 | if l != self.gnn_layer - 1: 86 | mfgs[l + 1][h].srcdata['h'] = rst 87 | else: 88 | out.append(rst) 89 | if self.num_snapshots == 1: 90 | out = out[0] 91 | else: 92 | out = torch.stack(out, dim=0) 93 | out = self.combiner(out)[0][-1, :, :] 94 | return out 95 | 96 | def reset(self): 97 | if self.use_memory: 98 | self.memory.reset() 99 | 100 | def resize(self, num_nodes: int): 101 | if self.use_memory: 102 | self.memory.resize(num_nodes) 103 | 104 | def has_memory(self): 105 | return self.use_memory 106 | 107 | def backup_memory(self) -> Dict: 108 | if self.use_memory: 109 | return self.memory.backup() 110 | return {} 111 | 112 | def restore_memory(self, backup: Dict): 113 | if self.use_memory: 114 | self.memory.restore(backup) -------------------------------------------------------------------------------- /gnnflow/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeterSH6/MSPipe/3e4f3efb998da621e756d07296025c63c385746a/gnnflow/models/modules/__init__.py -------------------------------------------------------------------------------- /gnnflow/models/modules/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the implementation of TGL's layer module. 
3 | 4 | Implementation at: 5 | https://github.com/amazon-research/tgl/blob/main/layers.py 6 | """ 7 | import logging 8 | import math 9 | import dgl 10 | import dgl.function as fn 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from dgl.heterograph import DGLBlock 16 | 17 | 18 | class TimeEncode(torch.nn.Module): 19 | """ 20 | Time encoding layer proposed by TGAT 21 | """ 22 | 23 | def __init__(self, dim_time: int): 24 | """ 25 | Args: 26 | dim: dimension of time features 27 | """ 28 | super(TimeEncode, self).__init__() 29 | self.w = torch.nn.Linear(1, dim_time) 30 | self.w.weight = nn.parameter.Parameter((torch.from_numpy( 31 | 1 / 10 ** np.linspace(0, 9, dim_time, dtype=np.float32))). 32 | reshape(dim_time, -1)) 33 | self.w.bias = nn.parameter.Parameter(torch.zeros(dim_time)) 34 | 35 | def forward(self, delta_time: torch.Tensor): 36 | output = torch.cos(self.w(delta_time.reshape((-1, 1)))) 37 | return output 38 | 39 | class FixTimeEncode(nn.Module): 40 | """ 41 | out = linear(time_scatter): 1-->time_dims 42 | out = cos(out) 43 | """ 44 | def __init__(self, dim): 45 | super(FixTimeEncode, self).__init__() 46 | self.dim = dim 47 | self.w = nn.Linear(1, dim) 48 | self.reset_parameters() 49 | 50 | def reset_parameters(self, ): 51 | self.w.weight = nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, self.dim, dtype=np.float32))).reshape(self.dim, -1)) 52 | self.w.bias = nn.Parameter(torch.zeros(self.dim)) 53 | 54 | self.w.weight.requires_grad = False 55 | self.w.bias.requires_grad = False 56 | 57 | @torch.no_grad() 58 | def forward(self, t): 59 | output = torch.cos(self.w(t.reshape((-1, 1)))) 60 | return output 61 | 62 | 63 | class TransfomerAttentionLayer(torch.nn.Module): 64 | """ 65 | Transfomer attention layer 66 | """ 67 | 68 | def __init__(self, dim_node: int, dim_edge: int, dim_time: int, 69 | dim_out: int, num_head: int, dropout: float, att_dropout: float): 70 | """ 71 | Args: 72 | dim_node: dimension of node features/embeddings 73 | dim_edge: dimension of edge features 74 | dim_time: dimension of time features 75 | dim_out: dimension of output embeddings 76 | num_head: number of heads 77 | dropout: dropout rate 78 | att_dropout: dropout rate for attention 79 | """ 80 | super(TransfomerAttentionLayer, self).__init__() 81 | # assert dim_node > 0 or dim_edge > 0, \ 82 | # "either dim_node or dim_edge should be positive" 83 | 84 | self.use_node_feat = dim_node > 0 85 | self.use_edge_feat = dim_edge > 0 86 | self.use_time_enc = dim_time > 0 87 | 88 | self.dim_node = dim_node 89 | self.num_head = num_head 90 | self.dim_time = dim_time 91 | self.dim_out = dim_out 92 | self.dropout = torch.nn.Dropout(dropout) 93 | self.att_dropout = torch.nn.Dropout(att_dropout) 94 | self.att_act = torch.nn.LeakyReLU(0.2) 95 | 96 | if self.use_time_enc: 97 | self.time_enc = TimeEncode(dim_time) 98 | # self.time_enc = FixTimeEncode(dim_time) 99 | 100 | if self.use_node_feat or self.use_time_enc: 101 | self.w_q = torch.nn.Linear(dim_node + dim_time, dim_out) 102 | else: 103 | self.w_q = torch.nn.Identity() 104 | 105 | self.w_k = torch.nn.Linear( 106 | dim_node + dim_edge + dim_time, dim_out) 107 | self.w_v = torch.nn.Linear( 108 | dim_node + dim_edge + dim_time, dim_out) 109 | 110 | self.w_out = torch.nn.Linear(dim_node + dim_out, dim_out) 111 | 112 | self.layer_norm = torch.nn.LayerNorm(dim_out) 113 | 114 | def forward(self, b: DGLBlock): 115 | """ 116 | Args: 117 | b: sampled message flow graph (mfg), where 118 | `b.num_dst_nodes()` is the 
number of target nodes to sample, 119 | `b.srcdata['h']` is the embedding of all nodes, 120 | `b.edge['f']` is the edge features of sampled edges, and 121 | `b.edata['dt']` is the delta time of sampled edges. 122 | 123 | Returns: 124 | output: output embedding of target nodes (shape: (num_dst_nodes, dim_out)) 125 | """ 126 | num_edges = b.num_edges() 127 | num_dst_nodes = b.num_dst_nodes() 128 | device = b.device 129 | 130 | # sample nothing (no neighbors) 131 | if num_edges == 0: 132 | return torch.zeros((num_dst_nodes, self.dim_out), device=device) 133 | 134 | if self.use_node_feat: 135 | target_node_embeddings = b.srcdata['h'][:num_dst_nodes] 136 | source_node_embeddings = b.srcdata['h'][num_dst_nodes:] 137 | else: 138 | # dummy node embeddings 139 | if self.use_time_enc: 140 | target_node_embeddings = torch.zeros( 141 | (num_dst_nodes, 0), device=device) 142 | else: 143 | target_node_embeddings = torch.ones( 144 | (num_dst_nodes, self.dim_out), device=device) 145 | 146 | source_node_embeddings = torch.zeros( 147 | (num_edges, 0), device=device) 148 | 149 | if self.use_edge_feat: 150 | edge_feats = b.edata['f'] 151 | else: 152 | # dummy edge features 153 | edge_feats = torch.zeros((num_edges, 0), device=device) 154 | 155 | if self.use_time_enc: 156 | delta_time = b.edata['dt'] 157 | time_feats = self.time_enc(delta_time) 158 | zero_time_feats = self.time_enc(torch.zeros( 159 | num_dst_nodes, dtype=torch.float32, device=device)) 160 | else: 161 | # dummy time features 162 | time_feats = torch.zeros((num_edges, 0), device=device) 163 | zero_time_feats = torch.zeros((num_dst_nodes, 0), device=device) 164 | 165 | assert isinstance(edge_feats, torch.Tensor) 166 | Q = torch.cat([target_node_embeddings, zero_time_feats], dim=1) 167 | K = torch.cat([source_node_embeddings, edge_feats, time_feats], dim=1) 168 | V = torch.cat([source_node_embeddings, edge_feats, time_feats], dim=1) 169 | 170 | Q = self.w_q(Q)[b.edges()[1]] 171 | K = self.w_k(K) 172 | V = self.w_v(V) 173 | 174 | Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1)) 175 | K = torch.reshape(K, (K.shape[0], self.num_head, -1)) 176 | V = torch.reshape(V, (V.shape[0], self.num_head, -1)) 177 | 178 | # compute attention scores 179 | att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2))) 180 | att = self.att_dropout(att) 181 | V = torch.reshape(V*att[:, :, None], (V.shape[0], -1)) 182 | 183 | b.srcdata['v'] = torch.cat((torch.zeros( 184 | (num_dst_nodes, V.shape[1]), device=device), V), dim=0) 185 | b.update_all(fn.copy_src('v', 'm'), fn.sum('m', 'h')) 186 | 187 | if self.use_node_feat: 188 | rst = torch.cat((b.dstdata['h'], target_node_embeddings), dim=1) 189 | else: 190 | rst = b.dstdata['h'] 191 | 192 | rst = self.w_out(rst) 193 | rst = F.relu(self.dropout(rst)) 194 | return self.layer_norm(rst) 195 | 196 | 197 | class EdgePredictor(torch.nn.Module): 198 | """ 199 | Edge prediction layer 200 | """ 201 | 202 | def __init__(self, dim_embed: int): 203 | """ 204 | Args: 205 | dim: dimension of embedding 206 | """ 207 | super(EdgePredictor, self).__init__() 208 | self.src_fc = torch.nn.Linear(dim_embed, dim_embed) 209 | self.dst_fc = torch.nn.Linear(dim_embed, dim_embed) 210 | self.out_fc = torch.nn.Linear(dim_embed, 1) 211 | 212 | def forward(self, h: torch.Tensor, neg_samples: int = 1): 213 | """ 214 | Args: 215 | h: embeddings of source, destination and negative sampling nodes 216 | """ 217 | num_edge = h.shape[0] // (neg_samples + 2) 218 | src_h, pos_dst_h, neg_dst_h = h.tensor_split((num_edge, 2 * num_edge)) 219 | src_h = 
self.src_fc(src_h) 220 | pos_dst_h = self.dst_fc(pos_dst_h) 221 | neg_dst_h = self.dst_fc(neg_dst_h) 222 | pos_edge = F.relu(src_h + pos_dst_h) 223 | neg_edge = F.relu(src_h.tile(neg_samples, 1) + neg_dst_h) 224 | return self.out_fc(pos_edge), self.out_fc(neg_edge) 225 | 226 | 227 | class MLP(torch.nn.Module): 228 | """ 229 | Node classification 230 | """ 231 | 232 | def __init__(self, dim_in, dim_hid, num_class): 233 | super(MLP, self).__init__() 234 | self.fc1 = torch.nn.Linear(dim_in, dim_hid) 235 | self.fc2 = torch.nn.Linear(dim_hid, num_class) 236 | 237 | def forward(self, x): 238 | x = self.fc1(x) 239 | x = torch.nn.functional.relu(x) 240 | x = self.fc2(x) 241 | return x 242 | 243 | 244 | class IdentityNormLayer(torch.nn.Module): 245 | 246 | def __init__(self, dim_out): 247 | super(IdentityNormLayer, self).__init__() 248 | self.norm = torch.nn.LayerNorm(dim_out) 249 | 250 | def forward(self, b): 251 | return self.norm(b.srcdata['h']) 252 | 253 | 254 | class JODIETimeEmbedding(torch.nn.Module): 255 | 256 | def __init__(self, dim_out): 257 | super(JODIETimeEmbedding, self).__init__() 258 | self.dim_out = dim_out 259 | 260 | class NormalLinear(torch.nn.Linear): 261 | # From Jodie code 262 | def reset_parameters(self): 263 | stdv = 1. / math.sqrt(self.weight.size(1)) 264 | self.weight.data.normal_(0, stdv) 265 | if self.bias is not None: 266 | self.bias.data.normal_(0, stdv) 267 | 268 | self.time_emb = NormalLinear(1, dim_out) 269 | 270 | def forward(self, h, mem_ts, ts): 271 | time_diff = (ts - mem_ts) / (ts + 1) 272 | rst = h * (1 + self.time_emb(time_diff.unsqueeze(1))) 273 | return rst 274 | -------------------------------------------------------------------------------- /gnnflow/models/modules/memory_updater.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is based on the implementation of TGL's memory module. 3 | 4 | Implementation at: 5 | https://github.com/amazon-research/tgl/blob/main/memorys.py 6 | """ 7 | import logging 8 | import os 9 | import torch 10 | from dgl.heterograph import DGLBlock 11 | 12 | from gnnflow.models.modules.layers import FixTimeEncode, TimeEncode 13 | 14 | 15 | class GRUMemeoryUpdater(torch.nn.Module): 16 | """ 17 | GRU memory updater proposed by TGN 18 | """ 19 | 20 | def __init__(self, dim_node: int, dim_edge: int, dim_time: int, 21 | dim_embed: int, dim_memory: int): 22 | """ 23 | Args: 24 | dim_node: dimension of node features/embeddings 25 | dim_edge: dimension of edge features 26 | dim_time: dimension of time features 27 | dim_embed: dimension of output embeddings 28 | dim_memory: dimension of memory 29 | """ 30 | super(GRUMemeoryUpdater, self).__init__() 31 | self.dim_message = 2 * dim_memory + dim_edge 32 | self.dim_node = dim_node 33 | self.dim_time = dim_time 34 | self.dim_embed = dim_embed 35 | self.updater = torch.nn.GRUCell( 36 | self.dim_message + self.dim_time, dim_memory) 37 | 38 | self.use_time_enc = dim_time > 0 39 | if self.use_time_enc: 40 | self.time_enc = TimeEncode(dim_time) 41 | # self.time_enc = FixTimeEncode(dim_time) 42 | 43 | if dim_node > 0 and dim_node != dim_memory: 44 | self.node_feat_proj = torch.nn.Linear(dim_node, dim_memory) 45 | 46 | def forward(self, b: DGLBlock): 47 | """ 48 | Update the memory of nodes 49 | 50 | Args: 51 | b: sampled message flow graph (mfg), where 52 | `b.num_dst_nodes()` is the number of target nodes to sample, 53 | `b.srcdata['ID']` is the node IDs of all nodes, and 54 | `b.srcdata['ts']` is the timestamp of all nodes. 
55 | 56 | Return: 57 | last_updated: { 58 | "last_updated_nid": node IDs of the target nodes 59 | "last_updated_memory": updated memory of the target nodes 60 | "last_updated_ts": timestamp of the target nodes 61 | } 62 | """ 63 | device = b.device 64 | 65 | if self.use_time_enc: 66 | time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts']) 67 | b.srcdata['mem_input'] = torch.cat( 68 | [b.srcdata['mem_input'], time_feat], dim=1) 69 | 70 | # TODO: A fake mail 71 | updated_memory = self.updater( 72 | b.srcdata['mem_input'], b.srcdata['mem']) 73 | 74 | # if int(os.environ['LOCAL_RANK']) == 0: 75 | # logging.info('mem input: {}'.format(b.srcdata['mem_input'])) 76 | # logging.info('mem : {}'.format(b.srcdata['mem'])) 77 | # logging.info('updated_memory: {}'.format(updated_memory)) 78 | # for name, param in self.updater.named_parameters(): 79 | # logging.info("name: {} param: {}".format(name, param[0])) 80 | 81 | num_dst_nodes = b.num_dst_nodes() 82 | last_updated_nid = b.srcdata['ID'][:num_dst_nodes].clone( 83 | ).detach().to(device) 84 | last_updated_memory = updated_memory[:num_dst_nodes].clone( 85 | ).detach().to(device) 86 | last_updated_ts = b.srcdata['ts'][:num_dst_nodes].clone( 87 | ).detach().to(device) 88 | 89 | if self.dim_node > 0: 90 | if self.dim_node == self.dim_embed: 91 | b.srcdata['h'] += updated_memory 92 | else: 93 | b.srcdata['h'] = updated_memory + \ 94 | self.node_feat_proj(b.srcdata['h']) 95 | else: 96 | b.srcdata['h'] = updated_memory 97 | 98 | return { 99 | "last_updated_nid": last_updated_nid, 100 | "last_updated_memory": last_updated_memory, 101 | "last_updated_ts": last_updated_ts 102 | } 103 | 104 | 105 | class RNNMemeoryUpdater(torch.nn.Module): 106 | """ 107 | RNN memory updater proposed by JODIE 108 | """ 109 | 110 | def __init__(self, dim_node: int, dim_edge: int, dim_time: int, 111 | dim_embed: int, dim_memory: int): 112 | """ 113 | Args: 114 | dim_node: dimension of node features/embeddings 115 | dim_edge: dimension of edge features 116 | dim_time: dimension of time features 117 | dim_embed: dimension of output embeddings 118 | dim_memory: dimension of memory 119 | """ 120 | super(RNNMemeoryUpdater, self).__init__() 121 | self.dim_message = 2 * dim_memory + dim_edge 122 | self.dim_node = dim_node 123 | self.dim_time = dim_time 124 | self.dim_embed = dim_embed 125 | 126 | self.updater = torch.nn.RNNCell( 127 | self.dim_message + self.dim_time, dim_memory) 128 | 129 | self.use_time_enc = dim_time > 0 130 | if self.use_time_enc: 131 | self.time_enc = TimeEncode(dim_time) 132 | 133 | if dim_node > 0 and dim_node != dim_memory: 134 | self.node_feat_proj = torch.nn.Linear(dim_node, dim_memory) 135 | 136 | def forward(self, b: DGLBlock): 137 | """ 138 | Update the memory of nodes 139 | 140 | Args: 141 | b: sampled message flow graph (mfg), where 142 | `b.num_dst_nodes()` is the number of target nodes to sample, 143 | `b.srcdata['ID']` is the node IDs of all nodes, and 144 | `b.srcdata['ts']` is the timestamp of all nodes. 
145 | 146 | Return: 147 | last_updated: { 148 | "last_updated_nid": node IDs of the target nodes 149 | "last_updated_memory": updated memory of the target nodes 150 | "last_updated_ts": timestamp of the target nodes 151 | } 152 | """ 153 | device = b.device 154 | 155 | if self.use_time_enc: 156 | time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts']) 157 | b.srcdata['mem_input'] = torch.cat( 158 | [b.srcdata['mem_input'], time_feat], dim=1) 159 | 160 | updated_memory = self.updater( 161 | b.srcdata['mem_input'], b.srcdata['mem']) 162 | 163 | # if int(os.environ['LOCAL_RANK']) == 0: 164 | # logging.info('mem input: {}'.format(b.srcdata['mem_input'])) 165 | # logging.info('mem : {}'.format(b.srcdata['mem'])) 166 | # logging.info('updated_memory: {}'.format(updated_memory)) 167 | # for name, param in self.updater.named_parameters(): 168 | # logging.info("name: {} param: {}".format(name, param[0])) 169 | 170 | num_dst_nodes = b.num_dst_nodes() 171 | last_updated_nid = b.srcdata['ID'][:num_dst_nodes].clone( 172 | ).detach().to(device) 173 | last_updated_memory = updated_memory[:num_dst_nodes].clone( 174 | ).detach().to(device) 175 | last_updated_ts = b.srcdata['ts'][:num_dst_nodes].clone( 176 | ).detach().to(device) 177 | 178 | if self.dim_node > 0: 179 | if self.dim_node == self.dim_embed: 180 | b.srcdata['h'] += updated_memory 181 | else: 182 | b.srcdata['h'] = updated_memory + \ 183 | self.node_feat_proj(b.srcdata['h']) 184 | else: 185 | b.srcdata['h'] = updated_memory 186 | 187 | return { 188 | "last_updated_nid": last_updated_nid, 189 | "last_updated_memory": last_updated_memory, 190 | "last_updated_ts": last_updated_ts 191 | } 192 | 193 | 194 | class TransformerMemoryUpdater(torch.nn.Module): 195 | 196 | def __init__(self, mailbox_size, att_head, dim_in, dim_out, dim_time, dropout, att_dropout): 197 | super(TransformerMemoryUpdater, self).__init__() 198 | self.mailbox_size = mailbox_size 199 | self.dim_time = dim_time 200 | self.att_h = att_head 201 | if dim_time > 0: 202 | self.time_enc = TimeEncode(dim_time) 203 | self.w_q = torch.nn.Linear(dim_out, dim_out) 204 | self.w_k = torch.nn.Linear(dim_in + dim_time, dim_out) 205 | self.w_v = torch.nn.Linear(dim_in + dim_time, dim_out) 206 | self.att_act = torch.nn.LeakyReLU(0.2) 207 | self.layer_norm = torch.nn.LayerNorm(dim_out) 208 | self.mlp = torch.nn.Linear(dim_out, dim_out) 209 | self.dropout = torch.nn.Dropout(dropout) 210 | self.att_dropout = torch.nn.Dropout(att_dropout) 211 | self.last_updated_memory = None 212 | self.last_updated_ts = None 213 | self.last_updated_nid = None 214 | 215 | def forward(self, b): 216 | # for b in mfg: 217 | Q = self.w_q(b.srcdata['mem']).reshape( 218 | (b.num_src_nodes(), self.att_h, -1)) 219 | # logging.info("b.srcdata['mem_input'] {}".format(b.srcdata['mem_input'].shape)) 220 | mails = b.srcdata['mem_input'].reshape( 221 | (b.num_src_nodes(), self.mailbox_size, -1)) 222 | if self.dim_time > 0: 223 | time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape( 224 | (b.num_src_nodes(), self.mailbox_size, -1)) 225 | mails = torch.cat([mails, time_feat], dim=2) 226 | K = self.w_k(mails).reshape( 227 | (b.num_src_nodes(), self.mailbox_size, self.att_h, -1)) 228 | V = self.w_v(mails).reshape( 229 | (b.num_src_nodes(), self.mailbox_size, self.att_h, -1)) 230 | att = self.att_act((Q[:, None, :, :]*K).sum(dim=3)) 231 | att = torch.nn.functional.softmax(att, dim=1) 232 | att = self.att_dropout(att) 233 | rst = (att[:, :, :, None]*V).sum(dim=1) 234 | rst = 
rst.reshape((rst.shape[0], -1)) 235 | rst += b.srcdata['mem'] 236 | rst = self.layer_norm(rst) 237 | rst = self.mlp(rst) 238 | rst = self.dropout(rst) 239 | rst = torch.nn.functional.relu(rst) 240 | b.srcdata['h'] = rst 241 | self.last_updated_memory = rst.detach().clone() 242 | self.last_updated_nid = b.srcdata['ID'].detach().clone() 243 | self.last_updated_ts = b.srcdata['ts'].detach().clone() 244 | 245 | return { 246 | "last_updated_nid": self.last_updated_nid, 247 | "last_updated_memory": self.last_updated_memory, 248 | "last_updated_ts": self.last_updated_ts 249 | } 250 | -------------------------------------------------------------------------------- /gnnflow/temporal_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import dgl 4 | import numpy as np 5 | import torch 6 | from dgl.heterograph import DGLBlock 7 | 8 | from libgnnflow import SamplingPolicy, SamplingResult, _TemporalSampler 9 | 10 | from .dynamic_graph import DynamicGraph 11 | 12 | 13 | class TemporalSampler: 14 | """ 15 | TemporalSampler samples k-hop multi-snapshots neighbors of given vertices. 16 | """ 17 | 18 | def __init__( 19 | self, graph: DynamicGraph, fanouts: List[int], 20 | sample_strategy: str = "recent", num_snapshots: int = 1, 21 | snapshot_time_window: float = 0.0, prop_time: bool = False, 22 | seed: int = 1234, *args, **kwargs): 23 | """ 24 | Initialize the sampler. 25 | 26 | Args: 27 | graph: the dynamic graph. 28 | fanouts: fanouts of each layer. 29 | samplle_strategy: sampling strategy, 'recent' or 'uniform' (case insensitive). 30 | num_snapshots: number of snapshots to sample. 31 | snapshot_time_window: time window every snapshot cover. It only makes 32 | sense when num_snapshots > 1. 33 | prop_time: whether to propagate timestamps to neighbors. 34 | seed: random seed. 35 | """ 36 | sample_strategy = sample_strategy.lower() 37 | if sample_strategy not in ["recent", "uniform"]: 38 | raise ValueError("strategy must be 'recent' or 'uniform'") 39 | 40 | if sample_strategy == "recent": 41 | sample_strategy = SamplingPolicy.RECENT 42 | else: 43 | sample_strategy = SamplingPolicy.UNIFORM 44 | 45 | print("TemporalSampler: sample_strategy={}, num_snapshots={}, snapshot_time_window={}, prop_time={}".format( 46 | sample_strategy, num_snapshots, snapshot_time_window, prop_time)) 47 | 48 | self._sampler = _TemporalSampler( 49 | graph._dgraph, fanouts, sample_strategy, num_snapshots, 50 | snapshot_time_window, prop_time, seed) 51 | self._num_layers = len(fanouts) 52 | self._num_snapshots = num_snapshots 53 | 54 | if 'is_static' in kwargs and kwargs['is_static'] == True: 55 | self._is_static = True 56 | else: 57 | self._is_static = False 58 | 59 | def sample(self, target_vertices: np.ndarray, timestamps: np.ndarray, reverse=False) -> List[List[DGLBlock]]: 60 | """ 61 | Sample k-hop neighbors of given vertices. 62 | 63 | Args: 64 | target_vertices: root vertices to sample. CPU tensor. 65 | timestamps: timestamps of target vertices in the graph. CPU tensor. 66 | 67 | Returns: 68 | list of message flow graphs (# of graphs = # of snapshots) for 69 | each layer. 
70 | """ 71 | if self._is_static: 72 | sampling_results = self._sampler.sample( 73 | target_vertices, 74 | np.full(target_vertices.shape, 75 | np.finfo(np.float32).max)) 76 | else: 77 | sampling_results = self._sampler.sample( 78 | target_vertices, timestamps) 79 | return self._to_dgl_block(sampling_results, reverse) 80 | 81 | def sample_layer(self, target_vertices: np.ndarray, timestamps: np.ndarray, 82 | layer: int, snapshot: int, to_dgl_block: bool = True) \ 83 | -> Union[DGLBlock, SamplingResult]: 84 | """ 85 | Sample neighbors of given vertices in a specific layer and snapshot. 86 | 87 | Args: 88 | target_vertices: root vertices to sample. CPU tensor. 89 | timestamps: timestamps of target vertices in the graph. CPU tensor. 90 | layer: layer to sample. 91 | snapshot: snapshot to sample. 92 | 93 | Returns: 94 | either a DGLBlock or a SamplingResult. 95 | """ 96 | sampling_result = self._sampler.sample_layer( 97 | target_vertices, timestamps, layer, snapshot) 98 | if to_dgl_block: 99 | return self._to_dgl_block_layer_snapshot(sampling_result) 100 | return sampling_result 101 | 102 | def _to_dgl_block(self, sampling_results: SamplingResult, reverse: bool) -> List[List[DGLBlock]]: 103 | mfgs = list() 104 | for sampling_results_layer in sampling_results: 105 | for r in sampling_results_layer: 106 | if not reverse: 107 | b = dgl.create_block( 108 | (r.col(), 109 | r.row()), 110 | num_src_nodes=r.num_src_nodes(), 111 | num_dst_nodes=r.num_dst_nodes()) 112 | b.srcdata['ID'] = torch.from_numpy(r.all_nodes()) 113 | b.edata['dt'] = torch.from_numpy(r.delta_timestamps()) 114 | b.srcdata['ts'] = torch.from_numpy(r.all_timestamps()) 115 | b.edata['ID'] = torch.from_numpy(r.eids()) 116 | else: 117 | b = dgl.create_block( 118 | (r.row(), 119 | r.col()), 120 | num_src_nodes=r.num_dst_nodes(), 121 | num_dst_nodes=r.num_src_nodes()) 122 | b.dstdata['ID'] = torch.from_numpy(r.all_nodes()) 123 | b.edata['dt'] = torch.from_numpy(r.delta_timestamps()) 124 | b.dstdata['ts'] = torch.from_numpy(r.all_timestamps()) 125 | b.edata['ID'] = torch.from_numpy(r.eids()) 126 | mfgs.append(b) 127 | mfgs = list(map(list, zip(*[iter(mfgs)] * self._num_snapshots))) 128 | mfgs.reverse() 129 | return mfgs 130 | 131 | def _to_dgl_block_layer_snapshot(self, sampling_result: SamplingResult) -> DGLBlock: 132 | mfg = dgl.create_block( 133 | (sampling_result.col(), 134 | sampling_result.row()), 135 | num_src_nodes=sampling_result.num_src_nodes(), 136 | num_dst_nodes=sampling_result.num_dst_nodes()) 137 | mfg.srcdata['ID'] = torch.from_numpy(sampling_result.all_nodes()) 138 | mfg.edata['dt'] = torch.from_numpy(sampling_result.delta_timestamps()) 139 | mfg.srcdata['ts'] = torch.from_numpy(sampling_result.all_timestamps()) 140 | mfg.edata['ID'] = torch.from_numpy(sampling_result.eids()) 141 | return mfg 142 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | brotlipy==0.7.0 2 | certifi==2022.5.18.1 3 | cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work 4 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 5 | cryptography @ file:///tmp/build/80754af9/cryptography_1652083738073/work 6 | dgl==0.8.1 7 | -e git+https://github.com/yuchenzhong/dynamic-graph-neural-network.git@60c14b0a0c950126922f634c4bda6b9529f168e0#egg=dgnn 8 | -e git+https://github.com/PeterSH6/TGNN-Staleness.git@1d7c569529517d2707d8ede8507089fa9e79715f#egg=gnnflow 9 | GPUtil==1.4.0 10 | idna @ 
file:///tmp/build/80754af9/idna_1637925883363/work 11 | Jinja2==3.1.2 12 | joblib==1.1.0 13 | MarkupSafe==2.1.1 14 | mkl-fft==1.3.1 15 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work 16 | mkl-service==2.4.0 17 | networkx @ file:///opt/conda/conda-bld/networkx_1647437648384/work 18 | numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1652801679809/work 19 | pandas==1.4.2 20 | parameterized==0.8.1 21 | Pillow==9.0.1 22 | psutil==5.9.4 23 | pybind11==2.9.2 24 | pycodestyle==2.9.1 25 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 26 | pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work 27 | pyparsing==3.0.9 28 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work 29 | python-dateutil==2.8.2 30 | pytz==2022.1 31 | PyYAML==6.0 32 | requests @ file:///opt/conda/conda-bld/requests_1641824580448/work 33 | scikit-learn==1.1.1 34 | scipy @ file:///tmp/build/80754af9/scipy_1641555001653/work 35 | six @ file:///tmp/build/80754af9/six_1644875935023/work 36 | threadpoolctl==3.1.0 37 | torch==1.11.0 38 | torch-cluster==1.6.0 39 | torch-geometric==2.0.4 40 | torch-scatter==2.0.9 41 | torch-sparse==0.6.13 42 | torch-spline-conv==1.2.1 43 | torchaudio==0.11.0 44 | torchvision==0.12.0 45 | tqdm @ file:///opt/conda/conda-bld/tqdm_1650891076910/work 46 | typing_extensions @ file:///opt/conda/conda-bld/typing_extensions_1647553014482/work 47 | urllib3 @ file:///opt/conda/conda-bld/urllib3_1650639997961/work 48 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # from https://github.com/amazon-research/tgl/blob/main/down.sh 3 | 4 | mkdir -p ../data/MOOC/ 5 | aria2c -x 16 -d ../data/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/edges.csv 6 | mkdir -p ../data/REDDIT/ 7 | aria2c -x 16 -d ../data/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edge_features.pt 8 | aria2c -x 16 -d ../data/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edges.csv 9 | aria2c -x 16 -d ../data/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/labels.csv 10 | mkdir -p ../data/WIKI 11 | aria2c -x 16 -d ../data/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edge_features.pt 12 | aria2c -x 16 -d ../data/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edges.csv 13 | aria2c -x 16 -d ../data/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/labels.csv 14 | mkdir -p ../data/LASTFM/ 15 | aria2c -x 16 -d ../data/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/edges.csv 16 | mkdir -p ../data/GDELT/ 17 | aria2c -x 16 -d ../data/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/node_features.pt 18 | aria2c -x 16 -d ../data/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/labels.csv 19 | aria2c -x 16 -d ../data/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edges.csv 20 | aria2c -x 16 -d ../data/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edge_features.pt 21 | mkdir -p ../data/MAG/ 22 | aria2c -x 16 -d ../data/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/labels.csv 23 | aria2c -x 16 -d ../data/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/edges.csv 24 | aria2c -x 16 -d ../data/MAG 
https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/node_features.pt 25 | 26 | -------------------------------------------------------------------------------- /scripts/pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import logging 4 | from typing import Iterable 5 | from dgl.heterograph import DGLBlock 6 | from gnnflow.utils import mfgs_to_cuda, node_to_dgl_blocks 7 | 8 | 9 | def sample(model, distributed, train_loader, sampler, queue_out): 10 | # logging.info('Sample Started') 11 | sample_time_sum = 0 12 | for i, (target_nodes, ts, eid) in enumerate(train_loader): 13 | if sampler is not None: 14 | model_name = type(model.module).__name__ if distributed else type(model).__name__ 15 | if model_name == 'APAN': 16 | mfgs = node_to_dgl_blocks(target_nodes, ts) 17 | target_pos = len(target_nodes) * 2 // 3 18 | block = sampler.sample( 19 | target_nodes[:target_pos], ts[:target_pos], reverse=True)[0][0] 20 | else: 21 | mfgs = sampler.sample(target_nodes, ts) 22 | block = None 23 | else: 24 | mfgs = node_to_dgl_blocks(target_nodes, ts) 25 | block = None 26 | queue_out.put((mfgs, block, eid)) 27 | # mfgs.share_memory() 28 | # logging.info('Sample done and send to feature fetching') 29 | # add signal that we are done 30 | queue_out.put(None) 31 | # logging.info('Sample in one epoch done') 32 | 33 | 34 | def left_training(mfgs, eid, model, device, cache, distributed, optimizer, criterion): 35 | mfgs_to_cuda(mfgs, device) 36 | # logging.info('move to cuda done') 37 | mfgs = cache.fetch_feature( 38 | mfgs, eid, target_edge_features=True) # because all use memory 39 | b = mfgs[0][0] # type: DGLBlock 40 | if distributed: 41 | model.module.memory.prepare_input(b) 42 | # model.module.last_updated = model.module.memory_updater(b) 43 | else: 44 | model.memory.prepare_input(b) 45 | # model.last_updated = model.memory_updater(b) 46 | last_updated = model.module.memory_updater(mfgs[0][0]) 47 | # logging.info('gnn mfgs {}'.format(mfgs)) 48 | optimizer.zero_grad() 49 | pred_pos, pred_neg = model(mfgs) 50 | loss = criterion(pred_pos, torch.ones_like(pred_pos)) 51 | loss += criterion(pred_neg, torch.zeros_like(pred_neg)) 52 | # total_loss += float(loss) * num_target_nodes 53 | loss.backward() 54 | optimizer.step() 55 | with torch.no_grad(): 56 | # use one function 57 | if distributed: 58 | model.module.memory.update_mem_mail( 59 | **last_updated, edge_feats=cache.target_edge_features.get(), 60 | neg_sample_ratio=1) 61 | else: 62 | model.memory.update_mem_mail( 63 | **last_updated, edge_feats=cache.target_edge_features.get(), 64 | neg_sample_ratio=1) 65 | 66 | 67 | def feature_fetching(cache, device, queue_in, queue_out, stream): 68 | # logging.info('feature fetching start') 69 | while True: 70 | # retrive from queue 71 | item = queue_in.get() 72 | # check for stop 73 | if item is None: 74 | queue_out.put(item) 75 | break 76 | global mfgs 77 | mfgs, block, eid = item 78 | # logging.info('feature mfgs {}'.format(mfgs)) 79 | # with torch.cuda.stream(stream): 80 | mfgs_to_cuda(mfgs, device) 81 | # logging.info('move to cuda done') 82 | mfgs = cache.fetch_feature( 83 | mfgs, eid, target_edge_features=True) # because all use memory 84 | queue_out.put((mfgs, block)) 85 | # logging.info('fetch feature done') 86 | # logging.info('feature fetching done') 87 | 88 | 89 | def memory_fetching(model, distributed, queue_in, queue_out, stream): 90 | while True: 91 | # retrive from queue 92 | item = queue_in.get() 93 | # check for stop 94 | if item is 
None: 95 | queue_out.put(item) 96 | break 97 | mfgs, block = item 98 | # logging.info('memory mfgs {}'.format(mfgs)) 99 | # with torch.cuda.stream(stream): 100 | b = mfgs[0][0] # type: DGLBlock 101 | if distributed: 102 | model.module.memory.prepare_input(b) 103 | # model.module.last_updated = model.module.memory_updater(b) 104 | else: 105 | model.memory.prepare_input(b) 106 | # model.last_updated = model.memory_updater(b) 107 | 108 | queue_out.put((mfgs, block)) 109 | 110 | # TODO: first test GNN first and Memory Last 111 | 112 | 113 | def gnn_training(model, optimizer, criterion, queue_in, queue_out, stream): 114 | # logging.info('gnn training start') 115 | neg_sample_ratio = 1 116 | while True: 117 | # retrive from queue 118 | item = queue_in.get() 119 | # check for stop 120 | if item is None: 121 | queue_out.put(item) 122 | break 123 | mfgs, block = item 124 | # with torch.cuda.stream(stream): 125 | last_updated = model.module.memory_updater(mfgs[0][0]) 126 | # logging.info('gnn mfgs {}'.format(mfgs)) 127 | optimizer.zero_grad() 128 | pred_pos, pred_neg = model(mfgs) 129 | loss = criterion(pred_pos, torch.ones_like(pred_pos)) 130 | loss += criterion(pred_neg, torch.zeros_like(pred_neg)) 131 | # total_loss += float(loss) * num_target_nodes 132 | loss.backward() 133 | optimizer.step() 134 | queue_out.put((last_updated, block)) # TODO: may not need? 135 | 136 | 137 | def memory_update(model, distributed, cache, queue_in, stream): 138 | # logging.info('memory update start') 139 | while True: 140 | # retrive from queue 141 | item = queue_in.get() 142 | # check for stop 143 | if item is None: 144 | break 145 | last_updated, block = item 146 | # NB: no need to do backward here 147 | # with torch.cuda.stream(stream): 148 | with torch.no_grad(): 149 | # use one function 150 | if distributed: 151 | model.module.memory.update_mem_mail( 152 | **last_updated, edge_feats=cache.target_edge_features.get(), 153 | neg_sample_ratio=1, block=block) 154 | else: 155 | model.memory.update_mem_mail( 156 | **last_updated, edge_feats=cache.target_edge_features.get(), 157 | neg_sample_ratio=1, block=block) 158 | -------------------------------------------------------------------------------- /scripts/pipeline_distributed.py: -------------------------------------------------------------------------------- 1 | 2 | from queue import Queue 3 | import torch 4 | import logging 5 | from typing import Iterable 6 | from dgl.heterograph import DGLBlock 7 | from gnnflow.distributed.dist_sampler import DistributedTemporalSampler 8 | from gnnflow.temporal_sampler import TemporalSampler 9 | from gnnflow.utils import mfgs_to_cuda 10 | 11 | 12 | def local_sample(train_loader, sampler: DistributedTemporalSampler, queue_out: Queue): 13 | # logging.info('Sample Started') 14 | sample_time_sum = 0 15 | # For TGN 16 | layer = 0 17 | snapshot = 0 18 | for i, (target_nodes, ts, eid) in enumerate(train_loader): 19 | futures, masks = sampler.sample_layer_first( 20 | target_nodes, ts, layer, snapshot) 21 | queue_out.put((futures, masks, target_nodes, ts, eid)) 22 | # mfgs.share_memory() 23 | # logging.info('Sample done and send to feature fetching') 24 | # add signal that we are done 25 | queue_out.put(None) 26 | # logging.info('Sample in one epoch done') 27 | 28 | 29 | def collect_sample_results(sampler: DistributedTemporalSampler, queue_in: Queue, queue_out: Queue): 30 | layer = 0 31 | while True: 32 | # retrive from queue 33 | item = queue_in.get() 34 | # check for stop 35 | if item is None: 36 | queue_out.put(item) 37 | break 38 | 
mfgs = [] 39 | mfgs.append([]) 40 | futures, masks, target_nodes, ts, eids = item 41 | mfgs[layer].append(sampler.sample_layer_collect( 42 | futures, masks, target_nodes, ts, layer)) 43 | mfgs.reverse() 44 | queue_out.put((mfgs, eids)) 45 | 46 | 47 | def sample(train_loader, sampler, queue_out): 48 | # logging.info('Sample Started') 49 | sample_time_sum = 0 50 | for i, (target_nodes, ts, eid) in enumerate(train_loader): 51 | mfgs = sampler.sample(target_nodes, ts) 52 | queue_out.put((mfgs, eid)) 53 | # mfgs.share_memory() 54 | # logging.info('Sample done and send to feature fetching') 55 | # add signal that we are done 56 | queue_out.put(None) 57 | # logging.info('Sample in one epoch done') 58 | 59 | 60 | def feature_fetching(cache, device, queue_in, queue_out): 61 | # logging.info('feature fetching start') 62 | while True: 63 | # retrive from queue 64 | item = queue_in.get() 65 | # check for stop 66 | if item is None: 67 | queue_out.put(item) 68 | break 69 | mfgs, eid = item 70 | # logging.info('feature mfgs {}'.format(mfgs)) 71 | mfgs_to_cuda(mfgs, device) 72 | # logging.info('move to cuda done') 73 | mfgs = cache.fetch_feature( 74 | mfgs, eid, target_edge_features=True) # because all use memory 75 | queue_out.put(mfgs) 76 | # logging.info('fetch feature done') 77 | # logging.info('feature fetching done') 78 | 79 | 80 | def feature_fetching_local(cache, device, queue_in, queue_out): 81 | while True: 82 | # retrive from queue 83 | item = queue_in.get() 84 | # check for stop 85 | if item is None: 86 | queue_out.put(item) 87 | break 88 | mfgs, eid = item 89 | # logging.info('feature mfgs {}'.format(mfgs)) 90 | mfgs_to_cuda(mfgs, device) 91 | # logging.info('move to cuda done') 92 | futures = cache.fetch_feature_local( 93 | mfgs, eid, target_edge_features=True) # because all use memory 94 | queue_out.put((mfgs, eid, futures)) 95 | 96 | 97 | def feature_fetching_collect(cache, device, queue_in, queue_out): 98 | while True: 99 | # retrive from queue 100 | item = queue_in.get() 101 | # check for stop 102 | if item is None: 103 | queue_out.put(item) 104 | break 105 | mfgs, eid, futures = item 106 | # logging.info('move to cuda done') 107 | mfgs = cache.fetch_feature_collect( 108 | *futures, mfgs, eid, target_edge_features=True) 109 | queue_out.put(mfgs) 110 | 111 | 112 | def memory_fetching(model, distributed, queue_in, queue_out): 113 | logging.info('memory fetching start') 114 | while True: 115 | # retrive from queue 116 | item = queue_in.get() 117 | # check for stop 118 | if item is None: 119 | queue_out.put(item) 120 | break 121 | mfgs = item 122 | # logging.info('memory mfgs {}'.format(mfgs)) 123 | b = mfgs[0][0] # type: DGLBlock 124 | if distributed: 125 | model.module.memory.prepare_input(b) 126 | # model.module.last_updated = model.module.memory_updater(b) 127 | else: 128 | model.memory.prepare_input(b) 129 | # model.last_updated = model.memory_updater(b) 130 | 131 | queue_out.put(mfgs) 132 | 133 | # TODO: first test GNN first and Memory Last 134 | 135 | 136 | def gnn_training(model, optimizer, criterion, queue_in, queue_out): 137 | # logging.info('gnn training start') 138 | neg_sample_ratio = 1 139 | while True: 140 | # retrive from queue 141 | item = queue_in.get() 142 | # check for stop 143 | if item is None: 144 | queue_out.put(item) 145 | break 146 | mfgs = item 147 | last_updated = model.module.memory_updater(mfgs[0][0]) 148 | # logging.info('gnn mfgs {}'.format(mfgs)) 149 | optimizer.zero_grad() 150 | pred_pos, pred_neg = model(mfgs) 151 | loss = criterion(pred_pos, 
torch.ones_like(pred_pos)) 152 | loss += criterion(pred_neg, torch.zeros_like(pred_neg)) 153 | # total_loss += float(loss) * num_target_nodes 154 | loss.backward() 155 | optimizer.step() 156 | queue_out.put(last_updated) # TODO: may not need? 157 | 158 | 159 | def memory_update(model, distributed, cache, queue_in): 160 | # logging.info('memory update start') 161 | while True: 162 | # retrive from queue 163 | item = queue_in.get() 164 | # check for stop 165 | if item is None: 166 | break 167 | last_updated = item 168 | # NB: no need to do backward here 169 | with torch.no_grad(): 170 | # use one function 171 | if distributed: 172 | model.module.memory.update_mem_mail( 173 | # cache must use a queue to maintain right version 174 | **last_updated, edge_feats=cache.target_edge_features.get(), 175 | neg_sample_ratio=1) 176 | else: 177 | model.memory.update_mem_mail( 178 | **last_updated, edge_feats=cache.target_edge_features.get(), 179 | neg_sample_ratio=1) 180 | -------------------------------------------------------------------------------- /scripts/run_offline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL=$1 4 | DATA=$2 5 | NPROC_PER_NODE=${3:-1} 6 | CACHE="${4:-LRUCache}" 7 | EDGE_CACHE_RATIO="${5:-0}" # default 0% of cache 8 | NODE_CACHE_RATIO="${6:-0}" # default 0% of cache 9 | TIME_WINDOW="${7:-0}" # default 0 10 | 11 | if [[ $NPROC_PER_NODE -gt 1 ]]; then 12 | cmd="torchrun \ 13 | --nnodes=1 --nproc_per_node=$NPROC_PER_NODE \ 14 | --standalone \ 15 | offline_edge_prediction_pipethread.py --model $MODEL --data $DATA \ 16 | --cache $CACHE --edge-cache-ratio $EDGE_CACHE_RATIO \ 17 | --node-cache-ratio $NODE_CACHE_RATIO --snapshot-time-window $TIME_WINDOW \ 18 | --ingestion-batch-size 10000000" 19 | else 20 | cmd="python offline_edge_prediction_pipethread.py --model $MODEL --data $DATA \ 21 | --cache $CACHE --edge-cache-ratio $EDGE_CACHE_RATIO \ 22 | --node-cache-ratio $NODE_CACHE_RATIO --snapshot-time-window $TIME_WINDOW \ 23 | --ingestion-batch-size 10000000" 24 | fi 25 | 26 | echo $cmd 27 | OMP_NUM_THREADS=8 exec $cmd > ${MODEL}_${DATA}_${NPROC_PER_NODE}_MSPipe.log 2>&1 28 | -------------------------------------------------------------------------------- /scripts/run_offline_dist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | INTERFACE="enp225s0" 3 | 4 | MODEL=$1 5 | DATA=$2 6 | CACHE="${3:-LRUCache}" 7 | EDGE_CACHE_RATIO="${4:-0}" # default 0% of cache 8 | NODE_CACHE_RATIO="${5:-0}" # default 0% of cache 9 | 10 | HOST_NODE_ADDR=10.28.1.31 11 | HOST_NODE_PORT=29400 12 | NNODES=2 13 | NPROC_PER_NODE=4 14 | 15 | CURRENT_NODE_IP=$(ip -4 a show dev ${INTERFACE} | grep inet | cut -d " " -f6 | cut -d "/" -f1) 16 | if [ $CURRENT_NODE_IP = $HOST_NODE_ADDR ]; then 17 | IS_HOST=true 18 | else 19 | IS_HOST=false 20 | fi 21 | 22 | export NCCL_SOCKET_IFNAME=${INTERFACE} 23 | export GLOO_SOCKET_IFNAME=${INTERFACE} 24 | export TP_SOCKET_IFNAME=${INTERFACE} 25 | 26 | cmd="torchrun \ 27 | --nnodes=$NNODES --nproc_per_node=$NPROC_PER_NODE \ 28 | --rdzv_id=1234 --rdzv_backend=c10d \ 29 | --rdzv_endpoint=$HOST_NODE_ADDR:$HOST_NODE_PORT \ 30 | --rdzv_conf is_host=$IS_HOST \ 31 | offline_edge_prediction_pipethread.py --model $MODEL --data $DATA \ 32 | --cache $CACHE --edge-cache-ratio $EDGE_CACHE_RATIO --node-cache-ratio $NODE_CACHE_RATIO\ 33 | --ingestion-batch-size 10000000 --epoch 10" 34 | 35 | rm -rf /dev/shm/rmm_pool_* 36 | rm -rf /dev/shm/edge_feats 37 | rm -rf 
/dev/shm/node_feats 38 | 39 | echo $cmd 40 | OMP_NUM_THREADS=8 exec $cmd > ${MODEL}_${DATA}_${CACHE}_${EDGE_CACHE_RATIO}_${NODE_CACHE_RATIO}_${NNODES}_${NPROC_PER_NODE}.log 2>&1 41 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | from queue import Queue 4 | from threading import Lock 5 | from time import sleep 6 | from typing import List 7 | import torch 8 | 9 | from gnnflow.utils import mfgs_to_cuda, node_to_dgl_blocks 10 | 11 | # global iter_mem_update 12 | iter_mem_update = 0 13 | def training_batch(model, sampler, cache, target_nodes, ts, eid, device, distributed, optimizer, criterion, Stream: torch.cuda.Stream, queue: Queue, lock_pool: List[Lock], i: int, rank: int, most_simliar=None, avg_cos_list=[]): 14 | with torch.cuda.stream(Stream): 15 | with lock_pool[0]: 16 | if sampler is not None: 17 | model_name = type(model.module).__name__ if distributed else type(model).__name__ 18 | if model_name == 'APAN': 19 | mfgs = node_to_dgl_blocks(target_nodes, ts) 20 | target_pos = len(target_nodes) * 2 // 3 21 | block = sampler.sample( 22 | target_nodes[:target_pos], ts[:target_pos], reverse=True)[0][0] 23 | else: 24 | mfgs = sampler.sample(target_nodes, ts) 25 | block = None 26 | else: 27 | mfgs = node_to_dgl_blocks(target_nodes, ts) 28 | block = None 29 | # lock_pool[0].release() 30 | with lock_pool[1]: 31 | mfgs_to_cuda(mfgs, device) 32 | mfgs = cache.fetch_feature( 33 | mfgs, eid, target_edge_features=True) # because all use memory 34 | with lock_pool[2]: 35 | b = mfgs[0][0] # type: DGLBlock 36 | global iter_mem_update 37 | if distributed: 38 | model.module.memory.prepare_input(b, most_simliar) 39 | else: 40 | model.memory.prepare_input(b, most_simliar) 41 | # global iter_mem_update 42 | # if rank == 0: 43 | # logging.info('iter {} staleness {}'.format(i, i - iter_mem_update)) 44 | with lock_pool[3]: 45 | # if rank == 0: 46 | # logging.info("gnn iter: {}".format(i)) 47 | with torch.cuda.stream(torch.cuda.default_stream()): 48 | if distributed: 49 | last_updated = model.module.memory_updater(mfgs[0][0]) 50 | else: 51 | last_updated = model.memory_updater(mfgs[0][0]) 52 | 53 | optimizer.zero_grad() 54 | 55 | pred_pos, pred_neg = model(mfgs) 56 | # logging.info('pred_pos{}'.format(pred_pos)) 57 | 58 | loss = criterion(pred_pos, torch.ones_like(pred_pos)) 59 | 60 | loss += criterion(pred_neg, torch.zeros_like(pred_neg)) 61 | # logging.info("iter: {} loss: {}".format(i, loss)) 62 | # total_loss += float(loss) * num_target_nodes 63 | loss.backward() 64 | optimizer.step() 65 | # if rank == 0: 66 | # logging.info("gnn iter: {}".format(i)) 67 | with lock_pool[4]: 68 | # if rank == 0: 69 | # logging.info("update iter: {}".format(i)) 70 | with torch.no_grad(): 71 | # use one function 72 | if distributed: 73 | model.module.memory.update_mem_mail( 74 | **last_updated, edge_feats=cache.target_edge_features.get(), 75 | neg_sample_ratio=1, block=block) 76 | else: 77 | model.memory.update_mem_mail( 78 | **last_updated, edge_feats=cache.target_edge_features.get(), 79 | neg_sample_ratio=1, block=block) 80 | 81 | # global iter_mem_update 82 | iter_mem_update = i 83 | # queue.get() 84 | # if rank == 0: 85 | # logging.info('current iteration: {}'.format(i)) 86 | # logging.info('current update: {}'.format(iter_mem_update)) 87 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | 5 | import torch.utils 6 | from setuptools import Extension, find_packages, setup 7 | from setuptools.command.build_ext import build_ext 8 | 9 | gnnflow_lib = Extension( 10 | "libgnnflow", sources=[] 11 | ) 12 | 13 | curdir = os.path.dirname(os.path.abspath(__file__)) 14 | 15 | 16 | def get_cmake_bin(): 17 | cmake_bin = "cmake" 18 | try: 19 | subprocess.check_output([cmake_bin, "--version"]) 20 | except OSError: 21 | raise RuntimeError( 22 | "Cannot find CMake executable. " 23 | "Please install CMake and try again." 24 | ) 25 | return cmake_bin 26 | 27 | 28 | class CustomBuildExt(build_ext): 29 | def build_extensions(self): 30 | cmake_bin = get_cmake_bin() 31 | 32 | debug = os.environ.get("DEBUG", "0") 33 | config = 'Debug' if debug == "1" else 'Release' 34 | print("Building with CMake config: {}".format(config)) 35 | 36 | ext_name = self.extensions[0].name 37 | build_dir = self.get_ext_fullpath(ext_name).replace( 38 | self.get_ext_filename(ext_name), '') 39 | build_dir = os.path.abspath(build_dir) 40 | 41 | cmake_args = [ 42 | "-DCMAKE_BUILD_TYPE={}".format(config), 43 | "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", 44 | "-DPYTHON_EXECUTABLE:FILEPATH={}".format(sys.executable), 45 | "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}".format(build_dir), 46 | "-DCMAKE_PREFIX_PATH={}".format(torch.utils.cmake_prefix_path), 47 | ] 48 | 49 | cmake_build_args = ['--config', config, '--', '-j'] 50 | 51 | if not os.path.exists(self.build_lib): 52 | os.makedirs(self.build_lib) 53 | 54 | os.chdir("build") 55 | 56 | try: 57 | subprocess.check_call([cmake_bin, "..", *cmake_args]) 58 | subprocess.check_call( 59 | [cmake_bin, "--build", ".", *cmake_build_args]) 60 | except subprocess.CalledProcessError as e: 61 | raise RuntimeError("CMake build failed") from e 62 | 63 | os.chdir(curdir) 64 | 65 | 66 | require_list = ["torch", "numpy"] 67 | 68 | test_require_list = ["unittest", "parameterized"] 69 | 70 | setup( 71 | name="gnnflow", 72 | version="0.0.1", 73 | description="A comprehensive framework for dynamic graph neural networks", 74 | license="Apache 2.0", 75 | url="https://github.com/jasperzhong/GNNFlow", 76 | packages=find_packages(exclude=("tests")), 77 | ext_modules=[gnnflow_lib], 78 | cmdclass={"build_ext": CustomBuildExt}, 79 | classifiers=[ 80 | "Programming Language :: Python :: 3", 81 | "License :: OSI Approved :: Apache Software License", 82 | "Operating System :: OS Independent" 83 | "Topic :: Machine Learning Package" 84 | ], 85 | python_requires='>=3.6', 86 | install_requires=require_list, 87 | tests_require=test_require_list, 88 | ) 89 | -------------------------------------------------------------------------------- /tests/test_build_graph.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import unittest 3 | 4 | import numpy as np 5 | from parameterized import parameterized 6 | 7 | from gnnflow.config import get_default_config 8 | from gnnflow.utils import build_dynamic_graph, load_dataset 9 | 10 | MB = 1 << 20 11 | GB = 1 << 30 12 | 13 | _, default_config = get_default_config("TGN", "REDDIT") 14 | 15 | 16 | class TestBuildGraph(unittest.TestCase): 17 | 18 | @parameterized.expand(itertools.product(["cuda", "unified", "pinned", "shared"])) 19 | def test_build_graph(self, mem_resource_type): 20 | """ 21 | Test building a dynamic graph from edges.csv(REDDIT) 22 | Only use training data to build a graph 23 | """ 24 | 
train_df, _, _, df = load_dataset(dataset="REDDIT") 25 | config = default_config.copy() 26 | config["mem_resource_type"] = mem_resource_type 27 | config["undirected"] = False 28 | dgraph = build_dynamic_graph(**config, dataset_df=train_df) 29 | 30 | train_edge_end = df[df['ext_roll'].gt(0)].index[0] 31 | srcs = np.array(df['src'][:train_edge_end], dtype=int) 32 | srcs = np.unique(srcs) 33 | 34 | dsts = np.array(df['dst'][:train_edge_end], dtype=int) 35 | dsts = np.unique(dsts) 36 | 37 | self.assertEqual(dgraph.max_vertex_id(), 38 | max(np.max(srcs), np.max(dsts))) 39 | 40 | # Test edges 41 | for src in srcs: 42 | df = df[:train_edge_end] 43 | out_edges = np.array(df[df['src'] == src]['dst'], dtype=int) 44 | ts = np.array(df[df['src'] == src]['time']) 45 | ts = np.flip(ts) 46 | 47 | graph_out_edges, graph_ts, _ = dgraph.get_temporal_neighbors(src) 48 | self.assertEqual(len(out_edges), len(graph_out_edges)) 49 | self.assertEqual(len(graph_out_edges), dgraph.out_degree([src])) 50 | self.assertTrue(np.allclose(ts, graph_ts)) 51 | 52 | @parameterized.expand(itertools.product(["cuda", "unified", "pinned", "shared"])) 53 | def test_build_graph_undirected(self, mem_resource_type): 54 | """ 55 | Test building a dynamic graph from edges.csv(REDDIT) 56 | Only use training data to build a graph 57 | """ 58 | train_df, _, _, df = load_dataset(dataset="REDDIT") 59 | config = default_config.copy() 60 | config["mem_resource_type"] = mem_resource_type 61 | config["undirected"] = True 62 | dgraph = build_dynamic_graph(**config, dataset_df=train_df) 63 | 64 | train_edge_end = df[df['ext_roll'].gt(0)].index[0] 65 | srcs = np.array(df['src'][:train_edge_end], dtype=int) 66 | srcs = np.unique(srcs) 67 | 68 | dsts = np.array(df['dst'][:train_edge_end], dtype=int) 69 | dsts = np.unique(dsts) 70 | 71 | srcs = np.concatenate((srcs, dsts)) 72 | 73 | self.assertEqual(dgraph.max_vertex_id(), 74 | max(np.max(srcs), np.max(dsts))) 75 | 76 | # Test edges 77 | for src in srcs: 78 | df = df[:train_edge_end] 79 | out_edges = np.array(df[df['src'] == src]['dst'], dtype=int) 80 | out_edges_reverse = np.array( 81 | df[df['dst'] == src]['src'], dtype=int) 82 | out_edges = np.concatenate((out_edges, out_edges_reverse)) 83 | ts = np.array(df[df['src'] == src]['time']) 84 | ts_reverse = np.array(df[df['dst'] == src]['time']) 85 | ts = np.concatenate((ts, ts_reverse)) 86 | ts = np.flip(ts) 87 | 88 | graph_out_edges, graph_ts, _ = dgraph.get_temporal_neighbors(src) 89 | self.assertEqual(len(out_edges), len(graph_out_edges)) 90 | self.assertEqual(len(graph_out_edges), dgraph.out_degree([src])) 91 | self.assertTrue(np.allclose(ts, graph_ts)) 92 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import time 2 | import unittest 3 | 4 | import numpy as np 5 | from torch.utils.data import BatchSampler, DataLoader, SequentialSampler 6 | 7 | from gnnflow.data import (EdgePredictionDataset, RandomStartBatchSampler, 8 | default_collate_ndarray) 9 | from gnnflow.utils import get_batch, load_dataset 10 | 11 | 12 | class TestDataset(unittest.TestCase): 13 | def test_loader(self, num_workers=0): 14 | train_df, val_df, test_df, df = load_dataset('REDDIT') 15 | 16 | ds = EdgePredictionDataset(train_df) 17 | 18 | sampler = BatchSampler(SequentialSampler( 19 | ds), batch_size=600, drop_last=False) 20 | 21 | a = DataLoader(dataset=ds, sampler=sampler, 22 | collate_fn=default_collate_ndarray, 23 | 
num_workers=num_workers) 24 | ti = 0 25 | ite = iter(get_batch(train_df, batch_size=600)) 26 | start_loader = 0 27 | avg_loader = 0 28 | avg_batch = 0 29 | for target, ts, eid in a: 30 | end_loader = time.time() 31 | loader_time = end_loader - start_loader 32 | start_batch = time.time() 33 | target1, ts1, eid1 = ite.__next__() 34 | end_batch = time.time() 35 | batch_time = end_batch - start_batch 36 | self.assertTrue(np.array_equal(target[:1200], target1[:1200])) 37 | self.assertTrue(np.array_equal(ts[:1200], ts1[:1200])) 38 | self.assertTrue(np.array_equal(eid[:1200], eid1[:1200])) 39 | ti = ti + 1 40 | if ti > 10: 41 | break 42 | if ti > 1: 43 | avg_loader += loader_time 44 | avg_batch += batch_time 45 | start_loader = time.time() 46 | 47 | print("avg loader time with {} workers: {}".format( 48 | num_workers, avg_loader / 10)) 49 | print("avg batch time: {}".format(avg_batch / 10)) 50 | 51 | def test_sampler(self): 52 | train_df, val_df, test_df, df = load_dataset('REDDIT') 53 | 54 | ds = EdgePredictionDataset(train_df) 55 | 56 | sampler = RandomStartBatchSampler(SequentialSampler( 57 | ds), batch_size=600, drop_last=False, num_chunks=8) 58 | 59 | a = DataLoader(dataset=ds, sampler=sampler, 60 | collate_fn=default_collate_ndarray, 61 | num_workers=0) 62 | ti = 0 63 | ite = iter(get_batch(train_df, batch_size=600)) 64 | start_loader = 0 65 | avg_loader = 0 66 | avg_batch = 0 67 | sampler.reset() 68 | for target, ts, eid in a: 69 | end_loader = time.time() 70 | loader_time = end_loader - start_loader 71 | start_batch = time.time() 72 | target1, ts1, eid1 = ite.__next__() 73 | end_batch = time.time() 74 | batch_time = end_batch - start_batch 75 | 76 | ti = ti + 1 77 | if ti > 10: 78 | break 79 | if ti > 1: 80 | avg_loader += loader_time 81 | avg_batch += batch_time 82 | start_loader = time.time() 83 | 84 | print("avg loader time: {}".format( 85 | avg_loader / 10)) 86 | print("avg batch time: {}".format(avg_batch / 10)) 87 | 88 | 89 | if __name__ == "__main__": 90 | unittest.main() 91 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import numpy as np 5 | from gnnflow.config import get_default_config 6 | from gnnflow.models.dgnn import DGNN 7 | from gnnflow.models.gat import GAT 8 | from gnnflow.models.graphsage import SAGE 9 | from gnnflow.temporal_sampler import TemporalSampler 10 | from gnnflow.utils import (build_dynamic_graph, get_batch, load_dataset, 11 | load_feat, mfgs_to_cuda, prepare_input) 12 | from gnnflow.cache import LFUCache 13 | 14 | 15 | class TestModel(unittest.TestCase): 16 | def test_tgn_forward(self): 17 | node_feats, edge_feats = load_feat('REDDIT') 18 | train_df, val_df, test_df, df = load_dataset('REDDIT') 19 | model_config, data_config = get_default_config('TGN', 'REDDIT') 20 | dgraph = build_dynamic_graph(**data_config, dataset_df=df) 21 | gnn_dim_node = 0 if node_feats is None else node_feats.shape[1] 22 | gnn_dim_edge = 0 if edge_feats is None else edge_feats.shape[1] 23 | batch_size = 600 24 | device = torch.device("cuda:0") 25 | model = DGNN( 26 | gnn_dim_node, gnn_dim_edge, **model_config, 27 | num_nodes=dgraph.num_vertices(), 28 | memory_device=device).to(device) 29 | 30 | sampler = TemporalSampler(dgraph, [10]) 31 | it = iter(get_batch(train_df, batch_size)) 32 | target_nodes, ts, eid = it.__next__() 33 | mfgs = sampler.sample(target_nodes, ts) 34 | 35 | mfgs = prepare_input(mfgs, 
node_feats, edge_feats) 36 | mfgs_to_cuda(mfgs, device) 37 | 38 | pred_pos, pred_neg = model(mfgs, eid=eid, edge_feats=edge_feats, 39 | neg_sample_ratio=0) 40 | 41 | def test_graph_sage_forward(self): 42 | edge_feats = torch.randn(411749, 172) 43 | node_feats = torch.randn(7144, 172) 44 | train_df, val_df, test_df, df = load_dataset('MOOC') 45 | model_config, data_config = get_default_config('TGN', 'MOOC') 46 | dgraph = build_dynamic_graph(**data_config, dataset_df=df) 47 | gnn_dim_node = 0 if node_feats is None else node_feats.shape[1] 48 | batch_size = 600 49 | device = torch.device("cuda:0") 50 | model = SAGE(gnn_dim_node, 100).to(device) 51 | 52 | sampler = TemporalSampler(dgraph, [3, 3, 3]) 53 | it = iter(get_batch(train_df, batch_size)) 54 | target_nodes, ts, eid = it.__next__() 55 | mfgs = sampler.sample(target_nodes, np.full( 56 | target_nodes.shape, np.finfo(np.float32).max)) 57 | 58 | cache = LFUCache(0.0005, 7144, 59 | dgraph.num_edges(), device, 60 | node_feats, edge_feats, 61 | 172, 172, 62 | None, 63 | None, 64 | None, False, 0) 65 | mfgs_to_cuda(mfgs, device) 66 | mfgs = cache.fetch_feature(mfgs, eid) 67 | mfgs = prepare_input(mfgs, node_feats, edge_feats) 68 | 69 | pred_pos, pred_neg = model(mfgs, neg_sample_ratio=0) 70 | 71 | def test_gat_forward(self): 72 | edge_feats = torch.randn(411749, 172) 73 | node_feats = torch.randn(7144, 172) 74 | train_df, val_df, test_df, df = load_dataset('MOOC') 75 | model_config, data_config = get_default_config('TGN', 'MOOC') 76 | dgraph = build_dynamic_graph(**data_config, dataset_df=df) 77 | gnn_dim_node = 0 if node_feats is None else node_feats.shape[1] 78 | batch_size = 600 79 | device = torch.device("cuda:0") 80 | model = GAT(gnn_dim_node, 100, allow_zero_in_degree=True).to(device) 81 | 82 | sampler = TemporalSampler(dgraph, [3, 3]) 83 | it = iter(get_batch(train_df, batch_size)) 84 | target_nodes, ts, eid = it.__next__() 85 | mfgs = sampler.sample(target_nodes, np.full( 86 | target_nodes.shape, np.finfo(np.float32).max)) 87 | 88 | mfgs = prepare_input(mfgs, node_feats, edge_feats) 89 | mfgs_to_cuda(mfgs, device) 90 | 91 | pred_pos, pred_neg = model(mfgs, neg_sample_ratio=0) 92 | 93 | 94 | if __name__ == "__main__": 95 | unittest.main() 96 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def is_sorted(tensor: torch.Tensor): 4 | return torch.all(torch.ge(tensor[1:], tensor[:-1])) 5 | -------------------------------------------------------------------------------- /tgl/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.5.1) 2 | project(tgl_gc) 3 | add_subdirectory(third_party/pybind11) 4 | pybind11_add_module(gen_graph_multithread gen_graph_multithread.cc) 5 | SET( CMAKE_CXX_FLAGS "-mcmodel=large -std=c++11 -O3 -lpthread") -------------------------------------------------------------------------------- /tgl/run_tgl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL=$1 4 | DATA=$2 5 | NPROC_PER_NODE=${3:-1} 6 | CACHE="${4:-LRUCache}" 7 | EDGE_CACHE_RATIO="${5:-0}" # default 0% of cache 8 | NODE_CACHE_RATIO="${6:-0}" # default 0% of cache 9 | TIME_WINDOW="${7:-0}" # default 0 10 | 11 | if [[ $NPROC_PER_NODE -gt 1 ]]; then 12 | cmd="torchrun \ 13 | --nnodes=1 --nproc_per_node=$NPROC_PER_NODE \ 14 | --standalone \ 15 | offline_tgl_presample.py 
--model $MODEL --data $DATA \ 16 | --cache $CACHE --edge-cache-ratio $EDGE_CACHE_RATIO \ 17 | --node-cache-ratio $NODE_CACHE_RATIO --snapshot-time-window $TIME_WINDOW \ 18 | --ingestion-batch-size 10000000" 19 | else 20 | cmd="python offline_tgl_presample.py --model $MODEL --data $DATA \ 21 | --cache $CACHE --edge-cache-ratio $EDGE_CACHE_RATIO \ 22 | --node-cache-ratio $NODE_CACHE_RATIO --snapshot-time-window $TIME_WINDOW \ 23 | --ingestion-batch-size 10000000" 24 | fi 25 | 26 | echo $cmd 27 | OMP_NUM_THREADS=8 exec $cmd > TGL_${MODEL}_${DATA}_${NPROC_PER_NODE}_presample.log 2>&1 28 | -------------------------------------------------------------------------------- /tgl/setup_tgl.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from setuptools import setup 3 | from pybind11.setup_helpers import Pybind11Extension 4 | 5 | ext_modules = [ 6 | Pybind11Extension("sampler_core", 7 | ['sampler_core.cpp'], 8 | extra_compile_args = ['-fopenmp'], 9 | extra_link_args = ['-fopenmp'],), 10 | ] 11 | 12 | setup( 13 | name = "sampler_core", 14 | version = "0.0.1", 15 | author = "Hongkuan Zhou", 16 | author_email = "hongkuaz@usc.edu", 17 | url = "https://tedzhouhk.github.io/about/", 18 | description = "Parallel Sampling for Temporal Graphs", 19 | ext_modules = ext_modules, 20 | ) -------------------------------------------------------------------------------- /tgl/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from tgl.sampler_core import ParallelSampler 4 | import torch 5 | import dgl 6 | 7 | def load_graph(d): 8 | # df = pd.read_csv('DATA/{}/edges.csv'.format(d)) 9 | g = np.load('../data/{}/ext_full.npz'.format(d)) 10 | return g 11 | 12 | def to_dgl_blocks(ret, hist, reverse=False): 13 | mfgs = list() 14 | for r in ret: 15 | if not reverse: 16 | b = dgl.create_block( 17 | (r.col(), r.row()), num_src_nodes=r.dim_in(), num_dst_nodes=r.dim_out()) 18 | b.srcdata['ID'] = torch.from_numpy(r.nodes()) 19 | b.edata['dt'] = torch.from_numpy(r.dts())[b.num_dst_nodes():] 20 | b.srcdata['ts'] = torch.from_numpy(r.ts()) 21 | else: 22 | b = dgl.create_block( 23 | (r.row(), r.col()), num_src_nodes=r.dim_out(), num_dst_nodes=r.dim_in()) 24 | b.dstdata['ID'] = torch.from_numpy(r.nodes()) 25 | b.edata['dt'] = torch.from_numpy(r.dts())[b.num_src_nodes():] 26 | b.dstdata['ts'] = torch.from_numpy(r.ts()) 27 | b.edata['ID'] = torch.from_numpy(r.eid()) 28 | 29 | mfgs.append(b) 30 | mfgs = list(map(list, zip(*[iter(mfgs)] * hist))) 31 | mfgs.reverse() 32 | return mfgs 33 | 34 | --------------------------------------------------------------------------------
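Note: the block below is a minimal end-to-end usage sketch assembled only from the calls already exercised in tests/test_model.py and the dataset layout used by scripts/download_data.sh; it is illustrative, not part of the repository, and assumes the REDDIT dataset has been downloaded and the C++ extension has been built via setup.py.

import torch

from gnnflow.config import get_default_config
from gnnflow.models.dgnn import DGNN
from gnnflow.temporal_sampler import TemporalSampler
from gnnflow.utils import (build_dynamic_graph, get_batch, load_dataset,
                           load_feat, mfgs_to_cuda, prepare_input)

# Load features, the edge list splits, and the default TGN/REDDIT configs.
node_feats, edge_feats = load_feat('REDDIT')
train_df, val_df, test_df, df = load_dataset('REDDIT')
model_config, data_config = get_default_config('TGN', 'REDDIT')

# Build the dynamic graph and a single-layer temporal sampler (10 neighbors).
dgraph = build_dynamic_graph(**data_config, dataset_df=df)
sampler = TemporalSampler(dgraph, [10])

device = torch.device('cuda:0')
gnn_dim_node = 0 if node_feats is None else node_feats.shape[1]
gnn_dim_edge = 0 if edge_feats is None else edge_feats.shape[1]
model = DGNN(gnn_dim_node, gnn_dim_edge, **model_config,
             num_nodes=dgraph.num_vertices(),
             memory_device=device).to(device)

# Take one training batch, sample its temporal neighborhood, attach features,
# move the message-flow graphs to GPU, and run a forward pass (as in
# tests/test_model.py::test_tgn_forward).
target_nodes, ts, eid = next(iter(get_batch(train_df, batch_size=600)))
mfgs = sampler.sample(target_nodes, ts)
mfgs = prepare_input(mfgs, node_feats, edge_feats)
mfgs_to_cuda(mfgs, device)
pred_pos, pred_neg = model(mfgs, eid=eid, edge_feats=edge_feats,
                           neg_sample_ratio=0)

For full (optionally multi-GPU) training runs, scripts/run_offline.sh wraps the same flow with torchrun; its positional arguments are MODEL DATA NPROC_PER_NODE CACHE EDGE_CACHE_RATIO NODE_CACHE_RATIO TIME_WINDOW, and scripts/run_offline_dist.sh extends it to multi-node rendezvous.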