├── .gitignore ├── README.md ├── build.sh ├── dataset ├── README.md ├── convert_to_bin.py ├── gen_legion_xtrapulp_fomat.cpp ├── gen_sets.py ├── webgraph-3.5.2.jar └── webgraph-3.6.8-deps.tar.gz ├── detail_parameter_settings └── README.md ├── graph_partitioning.py ├── legion_server.py ├── meta_config ├── prepare_dataset.sh ├── sampling_server ├── Makefile ├── sampling_server.cpp ├── setup.py └── src │ ├── cache │ ├── cache.cu │ ├── cache.cuh │ └── cache_impl.cuh │ ├── engine │ ├── helper_multiprocess.cu │ ├── helper_multiprocess.h │ ├── ipc_service.cu │ ├── ipc_service.h │ ├── memorypool.cu │ ├── memorypool.cuh │ ├── monitor.cuh │ ├── operator.cu │ ├── operator.h │ ├── operator_impl.cu │ ├── operator_impl.cuh │ ├── server.cu │ ├── server.h │ └── server_imp.cuh │ ├── include │ ├── buildinfo.h │ ├── hashmap.h │ ├── hashmap │ │ ├── CMakeLists.txt │ │ ├── bcht.hpp │ │ ├── benchmark_helpers.cuh │ │ ├── cht.hpp │ │ ├── cmd.hpp │ │ ├── detail │ │ │ ├── allocator.hpp │ │ │ ├── bcht_impl.cuh │ │ │ ├── benchmark_metrics.cuh │ │ │ ├── bucket.cuh │ │ │ ├── cht_impl.cuh │ │ │ ├── cuda_helpers.cuh │ │ │ ├── hash_functions.cuh │ │ │ ├── iht_impl.cuh │ │ │ ├── kernels.cuh │ │ │ ├── p2bht_impl.cuh │ │ │ ├── pair.cuh │ │ │ ├── pair_detail.hpp │ │ │ ├── ptx.cuh │ │ │ └── rng.hpp │ │ ├── genzipf.hpp │ │ ├── gpu_timer.hpp │ │ ├── iht.hpp │ │ ├── p2bht.hpp │ │ ├── perf_report.hpp │ │ └── rkg.hpp │ └── system_config.cuh │ ├── main.cu │ └── storage │ ├── feature_storage.cu │ ├── feature_storage.cuh │ ├── feature_storage_impl.cuh │ ├── graph_storage.cu │ ├── graph_storage.cuh │ ├── graph_storage_impl.cuh │ ├── storage_management.cu │ ├── storage_management.cuh │ └── storage_management_impl.cuh └── training_backend ├── helper_multiprocess.cpp ├── helper_multiprocess.h ├── ipc_cuda_kernel.cu ├── ipc_service.cpp ├── ipc_service.h ├── legion_gat.py ├── legion_gcn.py ├── legion_graphsage.py ├── lp_sage.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/ukunion 2 | dataset/lib 3 | dataset/xtrapulp 4 | dataset/xtrapulp_result 5 | dataset/gen_legion_xtrapulp_fomat 6 | sampling_server/build 7 | training_backend/build 8 | training_backend/dist 9 | training_backend/ipcservice.egg-info -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Legion is a GPU-initiated system with multi-GPU cache for large-scale GNN training. 2 | ``` 3 | $ git clone https://github.com/RC4ML/Legion.git 4 | ``` 5 | 6 | ## 1. Hardware 7 | ### Hardware Used in Our Paper 8 | All platforms are bare-metal machines. 9 | Table 1 10 | | Platform | CPU-Info | #sockets | #NUMA nodes | CPU Memory | PCIe | GPUs | NVLinks | 11 | | --- | --- | --- | --- | --- | --- | --- | --- | 12 | | DGX-V100 | 96*Intel(R) Xeon(R) Platinum 8163 CPU @2.5GHZ | 2 | 1 | 384GB | PCIe 3.0x16, 4*PCIe switches, each connecting 2 GPUs | 8x16GB-V100 | NVLink Bridges, Kc = 2, Kg = 4 | 13 | | Siton | 104*Intel(R) Xeon(R) Gold 5320 CPU @2.2GHZ | 2 | 2 | 1TB | PCIe 4.0x16, 2*PCIe switches, each connecting 4 GPUs | 8x40GB-A100 | NVLink Bridges, Kc = 4, Kg = 2 | 14 | | DGX-A100 | 128*Intel(R) Xeon(R) Platinum 8369B CPU @2.9GHZ | 2 | 1 | 1TB | PCIe 4.0x16, 4*PCIE switches, each connecting 2 GPUs | 8x80GB-A100 | NVSwitch, Kc = 1, Kg = 8 | 15 | 16 | Kc means the number of groups in which GPUs connect each other. And Kg means the number of GPUs in each group. 17 | 18 | 19 | ## 2. 
Software
20 | Legion's software stack is lightweight and portable. Here we list a tested environment.
21 | 
22 | 1. Nvidia Driver Version: 515.43.04
23 | 
24 | 2. CUDA 11.7
25 | 
26 | 3. GCC/G++ 11.4.0
27 | 
28 | 4. OS: Ubuntu (other Linux distributions should work as well)
29 | 
30 | 5. Intel PCM (pick the package matching your OS version)
31 | ```
32 | $ wget https://download.opensuse.org/repositories/home:/opcm/xUbuntu_18.04/amd64/pcm_0-0+651.1_amd64.deb
33 | ```
34 | 6. pytorch-cu117, torchmetrics (replace `cu1xx` below with your CUDA version, e.g. `cu117`)
35 | ```
36 | $ pip3 install torch-cu1xx
37 | ```
38 | 7. dgl 1.1.0
39 | ```
40 | $ pip3 install dgl -f https://data.dgl.ai/wheels/cu1xx/repo.html
41 | ```
42 | 8. MPI-3.1
43 | 
44 | 
45 | ## 3. Prepare Datasets and Graph Partitioning
46 | Datasets are from OGB (https://ogb.stanford.edu/), Stanford SNAP (https://snap.stanford.edu/), and WebGraph (https://webgraph.di.unimi.it/).
47 | Here is an example of preparing datasets for Legion.
48 | 
49 | ### Uk-Union Datasets
50 | Refer to the README in the dataset directory for more instructions.
51 | ```
52 | $ bash prepare_dataset.sh
53 | ```
54 | 
55 | ### Partition Uk-Union
56 | `gpu_num` is the total number of GPUs you want to use; Legion partitions the graph according to the underlying NVLink topology.
57 | Note that this step can consume a large volume of CPU memory.
58 | ```
59 | $ python graph_partitioning.py --dataset_name 'ukunion' --gpu_num 2
60 | ```
61 | 
62 | ## 4. Build Legion from Source
63 | 
64 | ```
65 | $ bash build.sh
66 | ```
67 | 
68 | ## 5. Run Legion
69 | There are three steps to train a GNN model in Legion. These steps require the **root** user for PCM. (2024.3.11: to work around PCM bugs on general platforms, PCM is disabled for now.)
70 | ### Step 1. Open msr by root for PCM
71 | ```
72 | $ modprobe msr
73 | ```
74 | ### Step 2. Start Legion Server
75 | 
76 | ```
77 | $ python legion_server.py --dataset_path 'dataset' --dataset_name ukunion --train_batch_size 8000 --fanout [25,10] --gpu_number 2 --epoch 2 --cache_memory 38000000
78 | ```
79 | 
80 | ### Step 3. Run Legion Training
81 | After Legion outputs "System is ready for serving", start training by:
82 | ```
83 | $ python training_backend/legion_graphsage.py --class_num 2 --features_num 128 --hidden_dim 256 --hops_num 2 --gpu_number 2 --epoch 2
84 | ```
85 | We will continuously improve the running process for easier use.
86 | 
87 | ## Cite this work
88 | If you use Legion in your paper, please cite our work:
89 | 
90 | ```
91 | @inproceedings {sun2023legion,
92 | author = {Jie Sun and Li Su and Zuocheng Shi and Wenting Shen and Zeke Wang and Lei Wang and Jie Zhang and Yong Li and Wenyuan Yu and Jingren Zhou and Fei Wu},
93 | title = {Legion: Automatically Pushing the Envelope of Multi-GPU System for Billion-Scale GNN Training},
94 | booktitle = {2023 USENIX Annual Technical Conference (USENIX ATC 23)},
95 | year = {2023},
96 | pages = {165--179}
97 | }
98 | ```
99 | 
100 | ## Future Features of Legion
101 | We will open-source SSD support for Legion in the future.
102 | 
-------------------------------------------------------------------------------- /build.sh: --------------------------------------------------------------------------------
1 | cd sampling_server && \
2 | make clean && make -j 8 && \
3 | cd .. && \
4 | cd training_backend && \
5 | python setup.py install && \
6 | cd .. 
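# After a successful build you should find (paths per the Makefile and
# training_backend/setup.py; exact artifact names assumed from this repo's layout):
#   sampling_server/build/bin/sampling_server   -- the sampling server binary
#   an installed `ipcservice` Python extension for the training backend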
7 | 
-------------------------------------------------------------------------------- /dataset/README.md: --------------------------------------------------------------------------------
1 | # This file elaborates on dataset preparation
2 | 
3 | ## Legion Format
4 | Take uk-union as an example
5 | ```
6 | Edge: uk-union/edge_src, uk-union/edge_dst ## These are topology files in CSR format. edge_src is int64, edge_dst is int32.
7 | feature: uk-union/features ## size = number of vertices x feature size
8 | label: uk-union/labels ## size = number of vertices
9 | partition: uk-union/partition ## partition_id for each vertex
10 | ```
11 | ## Customize your datasets
12 | ```
13 | cd dataset/
14 | ```
15 | Create the environment for webgraph
16 | ```
17 | mkdir lib
18 | mv webgraph-3.5.2.jar lib/
19 | tar -xzvf webgraph-3.6.8-deps.tar.gz -C lib
20 | ```
21 | Take uk-union as an example
22 | ```
23 | mkdir ukunion
24 | cd ukunion
25 | wget http://data.law.di.unimi.it/webdata/uk-union-2006-06-2007-05/uk-union-2006-06-2007-05-underlying.graph
26 | wget http://data.law.di.unimi.it/webdata/uk-union-2006-06-2007-05/uk-union-2006-06-2007-05-underlying.properties
27 | cd ..
28 | java -cp "lib/*" it.unimi.dsi.webgraph.ArcListASCIIGraph ukunion/uk-union-2006-06-2007-05-underlying ukunion/ukunion-edgelist.txt
29 | 
30 | mkdir xtrapulp_result
31 | # generate legion-format edge_src and edge_dst, and the input of xtrapulp
32 | g++ gen_legion_xtrapulp_fomat.cpp -o gen_legion_xtrapulp_fomat
33 | ./gen_legion_xtrapulp_fomat ukunion ukunion-edgelist.txt
34 | # generate training sets, validation sets, and test sets
35 | python gen_sets.py --dataset_name ukunion
36 | 
37 | ```
38 | 
39 | # 2. Graph partitioning
40 | ## Install MPI
41 | ```
42 | wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.0.tar.gz
43 | tar zxf openmpi-3.1.0.tar.gz
44 | cd openmpi-3.1.0
45 | sudo ./configure --prefix=/usr/local/openmpi
46 | sudo make
47 | sudo make install
48 | MPI_HOME=/usr/local/openmpi
49 | export PATH=${MPI_HOME}/bin:$PATH
50 | export LD_LIBRARY_PATH=${MPI_HOME}/lib:$LD_LIBRARY_PATH
51 | export MANPATH=${MPI_HOME}/share/man:$MANPATH
52 | 
53 | # or install via the package manager:
54 | # sudo apt-get install openmpi-bin openmpi-doc libopenmpi-dev
55 | # sudo apt-get install mpich libmpich-dev
56 | 
57 | ```
58 | ## Install xtrapulp (refer to https://github.com/luoxiaojian/xtrapulp)
59 | ```
60 | git clone https://github.com/luoxiaojian/xtrapulp.git
61 | mv ukunion_xtraformat xtrapulp/
62 | cd xtrapulp
63 | make
64 | make libxtrapulp
65 | cd ../../
66 | ```
-------------------------------------------------------------------------------- /dataset/convert_to_bin.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | 
3 | # Input file; 'data.txt' is assumed
4 | file_path = 'data.txt'
5 | 
6 | # Read the file with pandas' read_csv
7 | df = pd.read_csv(file_path, header=None, delimiter="\s+")
8 | 
9 | # Convert to a numpy array
10 | data = df.to_numpy()
11 | 
12 | print(data)
-------------------------------------------------------------------------------- /dataset/gen_legion_xtrapulp_fomat.cpp: --------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <fstream>
3 | #include <cstdio>
4 | #include <cstdlib>
5 | #include <cstring>
6 | #include <string>
7 | #include <vector>
8 | #include <algorithm>
9 | #include <unistd.h>
10 | #include <fcntl.h>
11 | #include <sys/mman.h>
12 | using namespace std;
13 | 
14 | int64_t node_nums = 133633040;
15 | int64_t edge_nums = 5507679822;
16 | 
17 | #define ten_e 1000000000
18 | #define one_e 100000000
19 | 
20 | vector<vector<int32_t>> vc;
21
| vector<int64_t> indptr;
22 | vector<int32_t> indices;
23 | vector<pair<int32_t, int32_t>> edges;
24 | 
25 | int64_t lines = 0;
26 | int cnt = 1;
27 | int reid = 0;
28 | int64_t qq = one_e;
29 | int64_t gl_nums = 0;
30 | bool flag[300000000];
31 | int32_t* idmap;
32 | 
33 | void read(std::string &file_path) {
34 | int fd = open(file_path.c_str(), O_RDONLY);
35 | int64_t buf_size = lseek(fd, 0, SEEK_END);
36 | char* buf = (char*)mmap(NULL, buf_size, PROT_READ, MAP_PRIVATE, fd, 0);
37 | const char* buf_end = buf + buf_size;
38 | int src, dst;
39 | std::string str = "";
40 | while(buf < buf_end) {
41 | if(*buf == '\t'){
42 | src = stoi(str);
43 | str = "";
44 | buf++;
45 | continue;
46 | }
47 | if (*buf == '\n') {
48 | lines++;
49 | if(lines%qq==0){
50 | cout<<"loaded "< " << std::endl;
113 | return 1;
114 | }
115 | std::string file_dir = argv[1];
116 | std::string file_name = argv[2];
117 | std::string file_path = file_dir + "/" + file_name;
118 | // string file_path = "ukunion/ukunion-edgelist.txt";
119 | 
120 | memset(flag,false,sizeof(flag));
121 | idmap = (int32_t*)malloc(int64_t(int64_t(300000000) * sizeof(int32_t)));
122 | 
123 | cout<<"edges loading:"< 1:
109 |     partition_command = [
110 |         "mpirun",
111 |         "-n", "4",
112 |         "./dataset/xtrapulp/xtrapulp",
113 |         "./dataset/xtrapulp/" + args.dataset_name+"_xtraformat",
114 |         str(int(args.gpu_num/group_size)),
115 |         "-v", "1.03",
116 |         "-l"
117 |     ]
118 |     print(partition_command)
119 |     result = subprocess.run(partition_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
120 | 
121 |     print("STDOUT:", result.stdout)
122 |     print("STDERR:", result.stderr)
123 | 
124 |     file_path = "./dataset/xtrapulp/"+args.dataset_name+"_xtraformat.parts."+str(int(args.gpu_num/group_size))
125 |     df = pd.read_csv(file_path, header=None, delimiter="\s+")
126 |     data = df.to_numpy()
127 |     data = data.astype(np.int32)
128 |     data.tofile('partition')
129 |     # print(data)
130 | 
131 |     move_command = [
132 |         "mv",
133 |         "partition",
134 |         "./dataset/"+args.dataset_name+"/"
135 |     ]
136 |     print(move_command)
137 |     result2 = subprocess.run(move_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
138 |     print("STDERR:", result2.stderr)
-------------------------------------------------------------------------------- /legion_server.py: --------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import subprocess
4 | import re
5 | import networkx as nx
6 | import math
7 | 
8 | def parse_topo_output(output):
9 |     """
10 |     Parses the output from `nvidia-smi topo -m` to extract the NVLink connections between GPUs.
11 |     Parsing is based on the standard matrix layout of that command.
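    For example (abridged sketch; real output varies by machine and driver):

                GPU0    GPU1    CPU Affinity
        GPU0     X      NV2     0-31
        GPU1    NV2      X      0-31

    Lines starting with "GPU" are scanned, and every entry beginning with
    "NV" is recorded as an NVLink edge between the row GPU and the column GPU.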
12 | """ 13 | connections = [] 14 | lines = output.splitlines() 15 | gpu_lines = [line for line in lines if line.startswith("GPU")] 16 | for i, line in enumerate(gpu_lines): 17 | elements = line.split() 18 | for j, elem in enumerate(elements[1:], start=0): # Start from the first GPU column 19 | if elem.startswith("NV"): 20 | connections.append((i, j)) 21 | return connections 22 | 23 | def get_nvlink_topology(): 24 | # Execute the `nvidia-smi topo -m` command to get the topology matrix 25 | result = subprocess.run(['nvidia-smi', 'topo', '-m'], stdout=subprocess.PIPE, text=True) 26 | connections = parse_topo_output(result.stdout) 27 | return connections 28 | 29 | def find_largest_fully_connected_group(G): 30 | """ 31 | Finds the largest fully connected group (clique) in the graph and returns its size 32 | and the list of such groups if there are multiple of the same size. 33 | """ 34 | cliques = list(nx.find_cliques(G)) 35 | max_size = max(len(clique) for clique in cliques) if cliques else 1 36 | max_cliques = [clique for clique in cliques if len(clique) == max_size] 37 | return max_size, max_cliques 38 | 39 | def Run(args): 40 | 41 | if args.dataset_name == "products": 42 | path = args.dataset_path + "/products/" 43 | vertices_num = 2449029 44 | edges_num = 123718280 45 | features_dim = 100 46 | train_set_num = 196615 47 | valid_set_num = 39323 48 | test_set_num = 2213091 49 | elif args.dataset_name == "paper100m": 50 | path = args.dataset_path + "/paper100M/" 51 | vertices_num = 111059956 52 | edges_num = 1615685872 53 | features_dim = 128 54 | train_set_num = 11105995 55 | valid_set_num = 100000 56 | test_set_num = 100000 57 | elif args.dataset_name == "com-friendster": 58 | path = args.dataset_path + "/com-friendster/" 59 | vertices_num = 65608366 60 | edges_num = 1806067135 61 | features_dim = 256 62 | train_set_num = 6560836 63 | valid_set_num = 100000 64 | test_set_num = 100000 65 | elif args.dataset_name == "ukunion": 66 | path = args.dataset_path + "/ukunion/" 67 | vertices_num = 133633040 68 | edges_num = 5507679822 69 | features_dim = 256 70 | train_set_num = 13363304 71 | valid_set_num = 100000 72 | test_set_num = 100000 73 | elif args.dataset_name == "uk2014": 74 | path = args.dataset_path + "/uk2014/" 75 | vertices_num = 787801471 76 | edges_num = 47284178505 77 | features_dim = 128 78 | train_set_num = 78780147 79 | valid_set_num = 100000 80 | test_set_num = 100000 81 | elif args.dataset_name == "clueweb": 82 | path = args.dataset_path + "/clueweb/" 83 | vertices_num = 955207488 84 | edges_num = 42574107469 85 | features_dim = 128 86 | train_set_num = 95520748 87 | valid_set_num = 100000 88 | test_set_num = 100000 89 | else: 90 | print("invalid dataset path") 91 | exit 92 | 93 | 94 | with open("meta_config","w") as file: 95 | file.write("{} {} {} {} {} {} {} {} {} {}".format(path, args.train_batch_size, vertices_num, edges_num, features_dim, train_set_num, valid_set_num, test_set_num, args.cache_memory, args.epoch)) 96 | 97 | gpu_number = args.gpu_number 98 | 99 | if args.usenvlink == 1: 100 | connections = get_nvlink_topology() 101 | G = nx.Graph() 102 | G.add_edges_from(connections) 103 | group_size, fully_connected_groups = find_largest_fully_connected_group(G) 104 | if fully_connected_groups or group_size == 1: 105 | print(f"NVLink clique size: {group_size}, Number of NVLink cliques: {int(gpu_number/group_size)}") 106 | cache_agg_mode = math.log2(group_size) 107 | else: 108 | cache_agg_mode = 0 109 | 110 | os.system("./sampling_server/build/bin/sampling_server {} 
{}".format(gpu_number, cache_agg_mode)) 111 | ## TODO, integrate Legion server in python module 112 | 113 | 114 | if __name__ == "__main__": 115 | 116 | argparser = argparse.ArgumentParser("Legion Server.") 117 | argparser.add_argument('--dataset_path', type=str, default="./dataset") 118 | argparser.add_argument('--dataset_name', type=str, default="ukunion") 119 | argparser.add_argument('--train_batch_size', type=int, default=8000) 120 | argparser.add_argument('--fanout', type=list, default=[25, 10]) 121 | argparser.add_argument('--gpu_number', type=int, default=2) 122 | argparser.add_argument('--epoch', type=int, default=2) 123 | argparser.add_argument('--cache_memory', type=int, default=38000000) 124 | argparser.add_argument('--usenvlink', type=int, default=1) 125 | args = argparser.parse_args() 126 | 127 | Run(args) 128 | -------------------------------------------------------------------------------- /meta_config: -------------------------------------------------------------------------------- 1 | /share/gnn_data/paper100M/ 8000 111059956 1615685872 128 11105995 100000 100000 38000000 2 -------------------------------------------------------------------------------- /prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | cd dataset/ 2 | # # create webgraph environment 3 | mkdir lib 4 | cp webgraph-3.5.2.jar lib/ 5 | tar -xzvf webgraph-3.6.8-deps.tar.gz -C lib 6 | 7 | mkdir ukunion 8 | cd ukunion 9 | wget http://data.law.di.unimi.it/webdata/uk-union-2006-06-2007-05/uk-union-2006-06-2007-05-underlying.graph 10 | wget http://data.law.di.unimi.it/webdata/uk-union-2006-06-2007-05/uk-union-2006-06-2007-05-underlying.properties 11 | cd .. 12 | java -cp "lib/*" it.unimi.dsi.webgraph.ArcListASCIIGraph ukunion/uk-union-2006-06-2007-05-underlying ukunion/ukunion-edgelist.txt 13 | 14 | mkdir xtrapulp_result 15 | # generate legion-format edge_src edge_dst, and the input of xtrapulp 16 | g++ gen_legion_xtrapulp_fomat.cpp -o gen_legion_xtrapulp_fomat 17 | ./gen_legion_xtrapulp_fomat ukunion ukunion-edgelist.txt 18 | # generate training sets, validation sets, and test sets 19 | python gen_sets.py --dataset_name ukunion 20 | 21 | # 2.If you don't have mpi, install mpi first. 
22 | 23 | # wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.0.tar.gz 24 | # tar zxf openmpi-3.1.0.tar.gz 25 | 26 | # cd openmpi-3.1.0 27 | # sudo ./configure --prefix=/usr/local/openmp 28 | # sudo make 29 | # sudo make install 30 | 31 | # MPI_HOME=/usr/local/openmpi 32 | # export PATH=${MPI_HOME}/bin:$PATH 33 | # export LD_LIBRARY_PATH=${MPI_HOME}/lib:$LD_LIBRARY_PATH 34 | # export MANPATH=${MPI_HOME}/share/man:$MANPATH 35 | 36 | # sudo apt-get install openmpi-bin openmpi-doc libopenmpi-dev 37 | # sudo apt-get install mpich libmpich-dev 38 | 39 | 40 | # # install xtrapulp, refer to https://github.com/luoxiaojian/xtrapulp 41 | git clone https://github.com/luoxiaojian/xtrapulp.git 42 | mv ukunion_xtraformat xtrapulp/ 43 | cd xtrapulp 44 | make 45 | make libxtrapulp 46 | cd ../../ 47 | -------------------------------------------------------------------------------- /sampling_server/Makefile: -------------------------------------------------------------------------------- 1 | default: bin 2 | 3 | HERE := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) 4 | ROOT := $(realpath $(HERE)) 5 | BUILD_DIR := $(ROOT)/build 6 | INCLUDE_DIR := $(BUILD_DIR)/include 7 | LIB_DIR := $(BUILD_DIR)/lib 8 | BIN_DIR := $(BUILD_DIR)/bin 9 | 10 | THIRD_PARTY_DIR := $(ROOT)/third_party 11 | PYTHON_DIR := $(ROOT)/python 12 | SRC_DIR = $(ROOT)/src 13 | 14 | 15 | clean: 16 | @rm -rf $(BUILD_DIR) 17 | 18 | ################################## compiler ####################################### 19 | 20 | NVCXX := nvcc 21 | NVCXXFLAGS := -std=c++17 -arch=sm_60 -Xcompiler -Wall \ 22 | --extended-lambda \ 23 | -Xcompiler -fPIC -lpthread \ 24 | --expt-relaxed-constexpr \ 25 | -I $(SRC_DIR)/include \ 26 | -I $(SRC_DIR)/cache \ 27 | -I $(SRC_DIR)/engine \ 28 | -I $(SRC_DIR)/storage \ 29 | -I $(SRC_DIR)/include/hashmap \ 30 | -g 31 | 32 | 33 | CXX := g++ 34 | CXXFLAGS := -std=c++17 -fPIC -O2 \ 35 | -I $(SRC_DIR)/engine -g 36 | 37 | ################################## bin ########################################### 38 | bin: core 39 | @mkdir -p $(BUILD_DIR)/built 40 | @mkdir -p $(INCLUDE_DIR) 41 | @mkdir -p $(LIB_DIR) 42 | @mkdir -p $(BIN_DIR) 43 | @echo $(CORE_H) 44 | $(NVCXX) $(NVCXXFLAGS) \ 45 | -L /usr/local/cuda/lib64 -L $(BUILD_DIR)/lib \ 46 | $(CORE_OBJ) \ 47 | -lcudart -lpthread \ 48 | -o $(BIN_DIR)/sampling_server -lrt 49 | 50 | 51 | ################################## shared ########################################### 52 | shared: core 53 | @mkdir -p $(BUILD_DIR)/built 54 | @mkdir -p $(INCLUDE_DIR) 55 | @mkdir -p $(LIB_DIR) 56 | @mkdir -p $(BIN_DIR) 57 | @echo $(CORE_H) 58 | $(NVCXX) $(NVCXXFLAGS) -shared \ 59 | -L /usr/local/cuda/lib64 \ 60 | $(CORE_OBJ) \ 61 | -lcudart -lpthread\ 62 | -o $(LIB_DIR)/libserver.so -lrt 63 | 64 | ################################## core ########################################### 65 | CORE_DIR := $(SRC_DIR) 66 | CORE_BUILT_DIR := $(BUILD_DIR)/built 67 | CORE_DIRS := $(shell find "src" -maxdepth 3 -type d) 68 | CORE_H := $(foreach dir,$(CORE_DIRS),$(wildcard $(dir)/*.h)) 69 | CORE_CUH := $(foreach dir,$(CORE_DIRS),$(wildcard $(dir)/*.cuh)) 70 | CORE_CU := $(foreach dir,$(CORE_DIRS),$(wildcard $(dir)/*.cu)) 71 | CORE_OBJ := $(addprefix $(CORE_BUILT_DIR)/,$(patsubst %.cu,%.o,$(CORE_CU))) 72 | 73 | $(CORE_BUILT_DIR)/src/%.o:$(CORE_DIR)/%.cu $(CORE_H) $(CORE_CUH) 74 | @mkdir -p $(CORE_BUILT_DIR) 75 | @mkdir -p $(CORE_BUILT_DIR)/src 76 | $(NVCXX) $(NVCXXFLAGS) -c $< -o $@ -lrt 77 | 78 | $(CORE_BUILT_DIR)/src/cache/%.o:$(CORE_DIR)/cache/%.cu $(CORE_H) $(CORE_CUH) 79 | 
@mkdir -p $(CORE_BUILT_DIR)/src/cache
80 | 	@echo $<
81 | 	$(NVCXX) $(NVCXXFLAGS) -c $< -o $@ -lrt
82 | 
83 | $(CORE_BUILT_DIR)/src/storage/%.o:$(CORE_DIR)/storage/%.cu $(CORE_H) $(CORE_CUH)
84 | 	@mkdir -p $(CORE_BUILT_DIR)/src/storage
85 | 	$(NVCXX) $(NVCXXFLAGS) -c $< -o $@ -lrt
86 | 
87 | $(CORE_BUILT_DIR)/src/engine/%.o:$(CORE_DIR)/engine/%.cu $(CORE_H) $(CORE_CUH)
88 | 	@mkdir -p $(CORE_BUILT_DIR)/src/engine
89 | 	$(NVCXX) $(NVCXXFLAGS) -c $< -o $@ -lrt
90 | 
91 | core: $(CORE_OBJ)
92 | 
93 | 
-------------------------------------------------------------------------------- /sampling_server/sampling_server.cpp: --------------------------------------------------------------------------------
1 | #include <pybind11/pybind11.h>
2 | #include <pybind11/stl.h>
3 | #include "server.h"
4 | #include <iostream>
5 | #include <vector>
6 | 
7 | int Run(const std::vector<int>& fanout, int gpu_number, int in_memory_mode, int cache_mode){
8 |     std::cout<<"Start Sampling Server\n";
9 |     Server* server = NewGPUServer();
10 |     server->Initialize(gpu_number, fanout, in_memory_mode);//gpu number, default 1; in memory, default true
11 |     server->PreSc(cache_mode);//cache aggregate mode, default 0
12 |     server->Run();
13 |     server->Finalize();
14 |     return 0;
15 | }
16 | 
17 | namespace py = pybind11;
18 | 
19 | PYBIND11_MODULE(sampling_server, m) {
20 |     m.doc() = "pybind11 plugin";
21 |     m.def("Run", &Run, "Run Sampling Server");
22 | }
23 | 
-------------------------------------------------------------------------------- /sampling_server/setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup, Extension
2 | import pybind11
3 | 
4 | # Extra linker arguments
5 | extra_link_args = ['-L build/lib', '-lserver'] # the shared library is assumed to live outside the standard paths
6 | custom_include_path = 'src/engine/'
7 | 
8 | sampling_module = Extension(
9 |     'sampling_server', # module name
10 |     sources=['sampling_server.cpp'],
11 |     include_dirs=[
12 |         custom_include_path,
13 |         pybind11.get_include() # add pybind11's header path
14 |     ],
15 |     language='c++',
16 |     extra_compile_args=['-std=c++11'], # C++ compiler flags
17 |     extra_link_args=extra_link_args, # linker flags
18 | )
19 | 
20 | setup(
21 |     name='sampling_server',
22 |     version='1.0',
23 |     description='Python package with C++ extension for Sampling Server',
24 |     ext_modules=[sampling_module],
25 | )
26 | 
27 | 
-------------------------------------------------------------------------------- /sampling_server/src/cache/cache.cuh: --------------------------------------------------------------------------------
1 | #ifndef CACHE_H
2 | #define CACHE_H
3 | 
4 | #include "graph_storage.cuh"
5 | #include "feature_storage.cuh"
6 | 
7 | #include <vector>
8 | #include <cstdint>
9 | 
10 | class CacheController{
11 | public:
12 | virtual ~CacheController() = default;
13 | 
14 | virtual void Initialize(
15 | int32_t dev_id,
16 | int32_t total_num_nodes) = 0;
17 | 
18 | virtual void Finalize() = 0;
19 | 
20 | virtual void FindFeat(
21 | int32_t* sampled_ids,
22 | int32_t* cache_offset,
23 | int32_t* node_counter,
24 | int32_t op_id,
25 | void* stream) = 0;
26 | 
27 | virtual void FindTopo(int32_t* input_ids,
28 | char* partition_index,
29 | int32_t* partition_offset,
30 | int32_t batch_size,
31 | int32_t op_id,
32 | void* strm_hdl,
33 | int32_t device_id) = 0;
34 | 
35 | virtual void CacheProfiling(
36 | int32_t* sampled_ids,
37 | int32_t* agg_src_id,
38 | int32_t* agg_dst_id,
39 | int32_t* agg_src_off,
40 | int32_t* agg_dst_off,
41 | int32_t* node_counter,
42 | int32_t* edge_counter,
43 | bool is_presc,
44 | void* stream) = 0;
45 | 
46 | virtual void InitializeMap(int node_capacity, int edge_capacity) = 0;
47 | 
48 
| virtual void Insert(int32_t* QT, int32_t* QF, int32_t cache_expand, int32_t Kg) = 0; 49 | 50 | virtual void HybridInsert(int32_t* QF, int32_t cpu_cache_capacity, int32_t gpu_cache_capacity) = 0; 51 | 52 | virtual void AccessCount( 53 | int32_t* d_key, 54 | int32_t num_keys, 55 | void* stream) = 0; 56 | 57 | virtual unsigned long long int* GetNodeAccessedMap() = 0; 58 | 59 | virtual unsigned long long int* GetEdgeAccessedMap() = 0; 60 | 61 | virtual int32_t MaxIdNum() = 0; 62 | }; 63 | 64 | CacheController* NewPreSCCacheController(int32_t train_step, int32_t device_count); 65 | 66 | class UnifiedCache{ 67 | public: 68 | void Initialize( 69 | int64_t cache_memory, 70 | int32_t float_feature_len, 71 | int32_t train_step, 72 | int32_t device_count, 73 | int32_t cpu_cache_capacity, 74 | int32_t gpu_cache_capacity); 75 | 76 | void InitializeCacheController( 77 | int32_t dev_id, 78 | int32_t total_num_nodes); 79 | 80 | void Finalize(int32_t dev_id); 81 | 82 | //these api will change, find, update, clear 83 | void FindFeat( 84 | int32_t* sampled_ids, 85 | int32_t* cache_offset, 86 | int32_t* node_counter, 87 | int32_t op_id, 88 | void* stream, 89 | int32_t dev_id); 90 | 91 | void FindTopo( 92 | int32_t* input_ids, 93 | char* partition_index, 94 | int32_t* partition_offset, 95 | int32_t batch_size, 96 | int32_t op_id, 97 | void* strm_hdl, 98 | int32_t dev_id); 99 | 100 | void CacheProfiling( 101 | int32_t* sampled_ids, 102 | int32_t* agg_src_id, 103 | int32_t* agg_dst_id, 104 | int32_t* agg_src_off, 105 | int32_t* agg_dst_off, 106 | int32_t* node_counter, 107 | int32_t* edge_counter, 108 | void* stream, 109 | int32_t dev_id); 110 | 111 | void AccessCount( 112 | int32_t* d_key, 113 | int32_t num_keys, 114 | void* stream, 115 | int32_t dev_id); 116 | 117 | void CandidateSelection(int cache_agg_mode, FeatureStorage* feature, GraphStorage* graph); 118 | 119 | void CostModel(int cache_agg_mode, FeatureStorage* feature, GraphStorage* graph, std::vector& counters, int32_t train_step); 120 | 121 | void FillUp(int cache_agg_mode, FeatureStorage* feature, GraphStorage* graph); 122 | 123 | void HybridInit(FeatureStorage* feature, GraphStorage* graph); 124 | 125 | int32_t MaxIdNum(int32_t dev_id); 126 | 127 | unsigned long long int* GetEdgeAccessedMap(int32_t dev_id); 128 | 129 | void FeatCacheLookup(int32_t* sampled_ids, int32_t* cache_index, 130 | int32_t* node_counter, float* dst_float_buffer, 131 | int32_t op_id, int32_t dev_id, cudaStream_t strm_hdl); 132 | 133 | private: 134 | int32_t NodeCapacity(int32_t dev_id); 135 | 136 | int32_t CPUCapacity(); 137 | 138 | int32_t GPUCapacity(); 139 | 140 | float* Float_Feature_Cache(int32_t dev_id);//return all features 141 | 142 | float** Global_Float_Feature_Cache(int32_t dev_id); 143 | 144 | std::vector dev_ids_;/*valid device, indexed by device id, False means invalid, True means valid*/ 145 | 146 | int32_t device_count_; 147 | 148 | std::vector cache_controller_; 149 | 150 | std::vector QF_; 151 | std::vector QT_; 152 | std::vector GF_; 153 | std::vector GT_; 154 | std::vector AF_; 155 | std::vector AT_; 156 | int Kc_; 157 | int Kg_; 158 | 159 | std::vector node_capacity_; 160 | std::vector edge_capacity_; 161 | 162 | int32_t cpu_cache_capacity_;//for legion ssd 163 | int32_t gpu_cache_capacity_;//for legion ssd 164 | 165 | int64_t cache_memory_; 166 | std::vector sidx_; 167 | 168 | std::vector int_feature_cache_; 169 | std::vector float_feature_cache_; 170 | std::vector d_float_feature_cache_ptr_; 171 | 172 | int32_t float_feature_len_; 173 | int32_t 
total_num_nodes_; 174 | float* cpu_float_features_; 175 | 176 | bool is_presc_; 177 | }; 178 | 179 | 180 | 181 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/helper_multiprocess.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | * 3 | * Redistribution and use in source and binary forms, with or without 4 | * modification, are permitted provided that the following conditions 5 | * are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of NVIDIA CORPORATION nor the names of its 12 | * contributors may be used to endorse or promote products derived 13 | * from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifndef HELPER_MULTIPROCESS_H 29 | #define HELPER_MULTIPROCESS_H 30 | 31 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 32 | #ifndef WIN32_LEAN_AND_MEAN 33 | #define WIN32_LEAN_AND_MEAN 34 | #endif 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include 41 | #include 42 | #include 43 | #else 44 | #include 45 | #include 46 | #include 47 | #include 48 | #include 49 | #include 50 | #include 51 | #include 52 | #include 53 | #include 54 | #endif 55 | #include 56 | 57 | typedef struct sharedMemoryInfo_st { 58 | void *addr; 59 | size_t size; 60 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 61 | HANDLE shmHandle; 62 | #else 63 | int shmFd; 64 | #endif 65 | } sharedMemoryInfo; 66 | 67 | int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info); 68 | 69 | int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info); 70 | 71 | void sharedMemoryClose(sharedMemoryInfo *info); 72 | 73 | 74 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 75 | typedef PROCESS_INFORMATION Process; 76 | #else 77 | typedef pid_t Process; 78 | #endif 79 | 80 | int spawnProcess(Process *process, const char *app, char * const *args); 81 | 82 | int waitProcess(Process *process); 83 | 84 | #define checkIpcErrors(ipcFuncResult) \ 85 | if (ipcFuncResult == -1) { fprintf(stderr, "Failure at %u %s\n", __LINE__, __FILE__); exit(EXIT_FAILURE); } 86 | 87 | #if defined(__linux__) 88 | struct ipcHandle_st { 89 | int socket; 90 | char *socketName; 91 | }; 92 | typedef int ShareableHandle; 93 | #elif defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 94 | struct ipcHandle_st { 95 | std::vector hMailslot; // 1 Handle in case of child and `num children` Handles for parent. 
96 | }; 97 | typedef HANDLE ShareableHandle; 98 | #endif 99 | 100 | typedef struct ipcHandle_st ipcHandle; 101 | 102 | int 103 | ipcCreateSocket(ipcHandle *&handle, const char *name, const std::vector& processes); 104 | 105 | int 106 | ipcOpenSocket(ipcHandle *&handle); 107 | 108 | int 109 | ipcCloseSocket(ipcHandle *handle); 110 | 111 | int 112 | ipcRecvShareableHandles(ipcHandle *handle, std::vector& shareableHandles); 113 | 114 | int 115 | ipcSendShareableHandles(ipcHandle *handle, const std::vector& shareableHandles, const std::vector& processes); 116 | 117 | int 118 | ipcCloseShareableHandle(ShareableHandle shHandle); 119 | 120 | #endif // HELPER_MULTIPROCESS_H 121 | -------------------------------------------------------------------------------- /sampling_server/src/engine/ipc_service.h: -------------------------------------------------------------------------------- 1 | #ifndef IPC_ENV_H 2 | #define IPC_ENV_H 3 | #include 4 | #include "buildinfo.h" 5 | 6 | class IPCEnv { 7 | public: 8 | virtual void Coordinate(BuildInfo* info) = 0; 9 | virtual int32_t GetMaxStep() = 0; 10 | 11 | virtual void InitializeSamplesBuffer(int32_t batch_size, int32_t num_ids, int32_t feature_dim, int32_t device_id, int32_t pipeline_depth) = 0; 12 | virtual void InitializeFeaturesBuffer(int32_t batch_size, int32_t num_ids, int32_t feature_dim, int32_t device_id, int32_t pipeline_depth) = 0; 13 | 14 | virtual int32_t GetRawBatchsize() = 0; 15 | virtual int32_t GetLocalBatchId(int32_t global_batch_id) = 0; 16 | virtual int32_t GetCurrentBatchsize(int32_t dev_id, int32_t current_mode) = 0; 17 | virtual int32_t GetCurrentMode(int32_t global_batch_id) = 0; 18 | 19 | virtual int32_t* GetIds(int32_t dev_id, int32_t current_pipe) = 0; 20 | virtual float* GetFloatFeatures(int32_t dev_id, int32_t current_pipe) = 0; 21 | virtual int32_t* GetLabels(int32_t dev_id, int32_t current_pipe) = 0; 22 | virtual int32_t* GetAggSrc(int32_t dev_id, int32_t current_pipe) = 0; 23 | virtual int32_t* GetAggDst(int32_t dev_id, int32_t current_pipe) = 0; 24 | virtual int32_t* GetNodeCounter(int32_t dev_id, int32_t current_pipe) = 0; 25 | virtual int32_t* GetEdgeCounter(int32_t dev_id, int32_t current_pipe) = 0; 26 | 27 | virtual void IPCPost(int32_t dev_id, int32_t current_pipe) = 0; 28 | virtual void IPCWait(int32_t dev_id, int32_t current_pipe) = 0; 29 | 30 | virtual void Finalize() = 0; 31 | virtual int32_t GetTrainStep() = 0; 32 | 33 | }; 34 | 35 | IPCEnv* NewIPCEnv(int32_t device_count); 36 | 37 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/memorypool.cu: -------------------------------------------------------------------------------- 1 | #include "memorypool.cuh" 2 | 3 | MemoryPool::MemoryPool(int32_t pipeline_depth){ 4 | pipeline_depth_ = pipeline_depth; 5 | current_pipe_ = 0; 6 | sampled_ids_.resize(pipeline_depth_); 7 | labels_.resize(pipeline_depth_); 8 | float_features_.resize(pipeline_depth_); 9 | agg_dst_off_.resize(pipeline_depth_); 10 | agg_src_off_.resize(pipeline_depth_); 11 | node_counter_.resize(pipeline_depth_); 12 | edge_counter_.resize(pipeline_depth_); 13 | } 14 | -------------------------------------------------------------------------------- /sampling_server/src/engine/memorypool.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef GPU_MEMORY_POOL 3 | #define GPU_MEMORY_POOL 4 | #include 5 | #include 6 | #include 7 | 8 | // Macro for checking cuda errors following a cuda launch or 
api call 9 | #define cudaCheckError() \ 10 | { \ 11 | cudaError_t e = cudaGetLastError(); \ 12 | if (e != cudaSuccess) { \ 13 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 14 | cudaGetErrorString(e)); \ 15 | exit(EXIT_FAILURE); \ 16 | } \ 17 | } 18 | 19 | 20 | class MemoryPool { 21 | public: 22 | MemoryPool(int32_t pipeline_depth); 23 | 24 | int32_t GetOpId(){//only used by sampler 25 | return op_id_; 26 | } 27 | 28 | int32_t GetIter(){ 29 | return iter_; 30 | } 31 | 32 | int32_t GetCurrentMode(){ 33 | return mode_; 34 | } 35 | 36 | float* GetFloatFeatures(){ 37 | return float_features_[current_pipe_]; 38 | } 39 | 40 | int32_t* GetCacheSearchBuffer(){ 41 | return cache_search_buffer_; 42 | } 43 | 44 | int32_t* GetLabels(){ 45 | return labels_[current_pipe_]; 46 | } 47 | 48 | uint32_t* GetAccessedMap(){ 49 | return accessed_map_; 50 | } 51 | 52 | int32_t* GetPositionMap(){ 53 | return position_map_; 54 | } 55 | 56 | int32_t* GetNodeCounter(){ 57 | return node_counter_[current_pipe_]; 58 | } 59 | 60 | int32_t* GetEdgeCounter(){ 61 | return edge_counter_[current_pipe_]; 62 | } 63 | 64 | int32_t* GetSampledIds(){ 65 | return sampled_ids_[current_pipe_]; 66 | } 67 | 68 | int32_t* GetAggSrcId(){ 69 | return agg_src_ids_; 70 | } 71 | 72 | int32_t* GetAggDstId(){ 73 | return agg_dst_ids_; 74 | } 75 | 76 | int32_t* GetAggSrcOf(){ 77 | return agg_src_off_[current_pipe_]; 78 | } 79 | 80 | int32_t* GetAggDstOf(){ 81 | return agg_dst_off_[current_pipe_]; 82 | } 83 | 84 | int32_t* GetTmpSrcOf(){ 85 | return tmp_src_of_; 86 | } 87 | 88 | int32_t* GetTmpDstOf(){ 89 | return tmp_dst_of_; 90 | } 91 | 92 | void* GetTempStorage(){ 93 | return temp_storage_; 94 | } 95 | 96 | char* GetTmpPartIdx(){ 97 | return tmp_part_ind_; 98 | } 99 | 100 | int32_t* GetTmpPartOff(){ 101 | return tmp_part_off_; 102 | } 103 | 104 | 105 | void SetFloatFeatures(float* float_features, int32_t current_pipe){ 106 | float_features_[current_pipe] = float_features; 107 | } 108 | 109 | void SetCacheSearchBuffer(int32_t* cache_search_buffer){ 110 | cache_search_buffer_ = cache_search_buffer; 111 | } 112 | 113 | void SetLabels(int32_t* labels, int32_t current_pipe){ 114 | labels_[current_pipe] = labels; 115 | } 116 | 117 | void SetAccessedMap(uint32_t* accessed_map){ 118 | accessed_map_ = accessed_map; 119 | } 120 | 121 | void SetPositionMap(int32_t* position_map){ 122 | position_map_ = position_map; 123 | } 124 | 125 | void SetNodeCounter(int32_t* node_counter, int32_t current_pipe){ 126 | node_counter_[current_pipe] = node_counter; 127 | } 128 | 129 | void SetEdgeCounter(int32_t* edge_counter, int32_t current_pipe){ 130 | edge_counter_[current_pipe] = edge_counter; 131 | } 132 | 133 | void SetSampledIds(int32_t* sampled_ids, int32_t current_pipe){ 134 | sampled_ids_[current_pipe] = sampled_ids; 135 | } 136 | 137 | void SetAggSrcId(int32_t* agg_src_ids){ 138 | agg_src_ids_ = agg_src_ids; 139 | } 140 | 141 | void SetAggDstId(int32_t* agg_dst_ids){ 142 | agg_dst_ids_ = agg_dst_ids; 143 | } 144 | 145 | void SetAggSrcOf(int32_t* agg_src_off, int32_t current_pipe){ 146 | agg_src_off_[current_pipe] = agg_src_off; 147 | } 148 | 149 | void SetAggDstOf(int32_t* agg_dst_off, int32_t current_pipe){ 150 | agg_dst_off_[current_pipe] = agg_dst_off; 151 | } 152 | 153 | void SetTmpSrcOf(int32_t* tmp_src_of){ 154 | tmp_src_of_ = tmp_src_of; 155 | } 156 | 157 | void SetTmpDstOf(int32_t* tmp_dst_of){ 158 | tmp_dst_of_ = tmp_dst_of; 159 | } 160 | 161 | void SetTempStorage(void* temp_storage){ 162 | temp_storage_ = temp_storage; 163 | } 164 | 
165 | void SetTmpPartIdx(char* tmp_part_ind){ 166 | tmp_part_ind_ = tmp_part_ind; 167 | } 168 | 169 | void SetTmpPartOff(int32_t* tmp_part_off){ 170 | tmp_part_off_ = tmp_part_off; 171 | } 172 | 173 | void SetOpId(int32_t op_id){ 174 | op_id_ = op_id; 175 | } 176 | 177 | void SetCurrentPipe(int32_t current_pipe){ 178 | current_pipe_ = current_pipe; 179 | } 180 | 181 | void SetCurrentMode(int32_t mode){ 182 | mode_ = mode; 183 | } 184 | 185 | void SetIter(int32_t iter){ 186 | iter_ = iter; 187 | } 188 | 189 | void Finalize() { 190 | cudaFree(cache_search_buffer_); 191 | cudaFree(accessed_map_); 192 | cudaFree(position_map_); 193 | cudaFree(agg_src_ids_); 194 | cudaFree(agg_dst_ids_); 195 | } 196 | 197 | private: 198 | int32_t iter_; 199 | int32_t mode_; 200 | int32_t op_id_; 201 | int32_t* cache_search_buffer_; 202 | uint32_t* accessed_map_; 203 | int32_t* position_map_; 204 | int32_t* agg_src_ids_; 205 | int32_t* agg_dst_ids_; 206 | int32_t* tmp_src_of_; 207 | int32_t* tmp_dst_of_; 208 | void* temp_storage_; 209 | char* tmp_part_ind_; 210 | int32_t* tmp_part_off_; 211 | 212 | int32_t pipeline_depth_; 213 | int32_t current_pipe_; 214 | std::vector float_features_; 215 | std::vector labels_; 216 | std::vector node_counter_; 217 | std::vector edge_counter_; 218 | std::vector sampled_ids_; 219 | std::vector agg_src_off_; 220 | std::vector agg_dst_off_; 221 | }; 222 | 223 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/monitor.cuh: -------------------------------------------------------------------------------- 1 | #ifndef MONITOR_H 2 | #define MONITOR_H 3 | 4 | #include 5 | #ifdef _MSC_VER 6 | #include 7 | #include "windows/windriver.h" 8 | #else 9 | #include 10 | #include 11 | #endif 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | // #include 19 | #include 20 | #include "src/pcm-pcie.h" 21 | // #include "cpucounters.h" 22 | 23 | #define PCM_DELAY_DEFAULT 1.0 // in seconds 24 | #define PCM_DELAY_MIN 0.015 // 15 milliseconds is practical on most modern CPUs 25 | 26 | using namespace std; 27 | 28 | bool events_printed = false; 29 | 30 | // #include "zerocp.h" 31 | 32 | IPlatform *IPlatform::getPlatform(PCM *m, bool csv, bool print_bandwidth, bool print_additional_info, uint32 delay) 33 | { 34 | switch (m->getCPUModel()) { 35 | case PCM::SPR: 36 | case PCM::EMR: 37 | std::cout<<"EagleStream"<getCPUModel()) { 65 | // case PCM::ICX: 66 | // case PCM::SNOWRIDGE: 67 | // return new WhitleyPlatform(m, csv, print_bandwidth, print_additional_info, delay); 68 | // case PCM::SKX: 69 | // return new PurleyPlatform(m, csv, print_bandwidth, print_additional_info, delay); 70 | // case PCM::BDX_DE: 71 | // case PCM::BDX: 72 | // case PCM::KNL: 73 | // case PCM::HASWELLX: 74 | // return new GrantleyPlatform(m, csv, print_bandwidth, print_additional_info, delay); 75 | // case PCM::IVYTOWN: 76 | // case PCM::JAKETOWN: 77 | // return new BromolowPlatform(m, csv, print_bandwidth, print_additional_info, delay); 78 | // default: 79 | // return NULL; 80 | // } 81 | // } 82 | 83 | class PCM_Monitor { 84 | public: 85 | void Init(); 86 | // void Start(); 87 | // void Stop(); 88 | // void Print(); 89 | // std::vector GetCounter(); 90 | 91 | private: 92 | std::vector counter_; 93 | IPlatform* platform_; 94 | vector *eventGroups_; 95 | PCM * m_; 96 | }; 97 | 98 | // #endif 99 | 100 | void PCM_Monitor::Init(){ 101 | 102 | set_signal_handlers(); 103 | 104 | cerr << "\n"; 105 | cerr << " Intel(r) Performance Counter 
Monitor: PCIe Bandwidth Monitoring Utility \n"; 106 | cerr << " This utility measures PCIe bandwidth in real-time\n"; 107 | cerr << "\n"; 108 | 109 | // double delay = 1.0; 110 | bool csv = false; 111 | bool print_bandwidth = true; 112 | bool print_additional_info = false; 113 | // char * sysCmd = NULL; 114 | // char ** sysArgv = NULL; 115 | MainLoop mainLoop; 116 | 117 | m_= PCM::getInstance(); 118 | std::cout<<"get pcm instance"<getEventGroups(); 132 | 133 | platform_->cleanup(); 134 | MySleepMs(uint(1000)); 135 | } 136 | 137 | // void PCM_Monitor::Start(){ 138 | // printf("Start Count PCIe\n"); 139 | // for (auto &evGroup : *eventGroups_){ 140 | // m_->programPCIeEventGroup(evGroup); 141 | // platform_->getEventGroup(evGroup, 0); 142 | // } 143 | // for (auto &evGroup : *eventGroups_){ 144 | // m_->programPCIeEventGroup(evGroup); 145 | // } 146 | // } 147 | 148 | // void PCM_Monitor::Stop(){ 149 | // for (auto &evGroup : *eventGroups_){ 150 | // platform_->getEventGroup(evGroup, 1); 151 | // } 152 | // printf("Stop Count PCIe\n"); 153 | // } 154 | 155 | // void PCM_Monitor::Print(){ 156 | // platform_->printHeader(); 157 | 158 | // platform_->printEvents(); 159 | 160 | // platform_->printAggregatedEvents(); 161 | // } 162 | 163 | // std::vector PCM_Monitor::GetCounter(){ 164 | 165 | // return platform_->GetCounter(); 166 | // } 167 | 168 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/operator.cu: -------------------------------------------------------------------------------- 1 | #include "operator.h" 2 | #include "operator_impl.cuh" 3 | #include "storage_management.cuh" 4 | #include "graph_storage.cuh" 5 | #include "feature_storage.cuh" 6 | #include "ipc_service.h" 7 | #include "cache.cuh" 8 | #include "memorypool.cuh" 9 | 10 | class BatchGenerateOP : public Operator { 11 | public: 12 | BatchGenerateOP(int op_id){ 13 | op_id_ = op_id; 14 | } 15 | void run(OpParams* params) override { 16 | FeatureStorage* feature = (FeatureStorage*)(params->feature); 17 | UnifiedCache* cache = (UnifiedCache*)(params->cache); 18 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 19 | IPCEnv* env = (IPCEnv*)(params->env); 20 | int32_t device_id = params->device_id; 21 | int32_t mode = memorypool->GetCurrentMode(); 22 | int32_t iter = memorypool->GetIter(); 23 | int32_t batch_size = env->GetCurrentBatchsize(device_id, mode); 24 | bool is_presc = params->is_presc; 25 | int32_t hop_num = params->hop_num; 26 | 27 | BatchGenerate(params->stream, feature, cache, memorypool, batch_size, iter, device_id, device_id, mode, is_presc, hop_num); 28 | cudaEventRecord(((params->event)), ((params->stream))); 29 | cudaCheckError(); 30 | } 31 | private: 32 | int op_id_; 33 | }; 34 | 35 | Operator* NewBatchGenerateOP(int op_id){ 36 | return new BatchGenerateOP(op_id); 37 | } 38 | 39 | class RandomSampleOP : public Operator { 40 | public: 41 | RandomSampleOP(int op_id){ 42 | op_id_ = op_id; 43 | } 44 | void run(OpParams* params) override { 45 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 46 | GraphStorage* graph = (GraphStorage*)(params->graph); 47 | UnifiedCache* cache = (UnifiedCache*)(params->cache); 48 | bool is_presc = params->is_presc; 49 | int32_t count = params->neighbor_count; 50 | int32_t device_id = params->device_id; 51 | 52 | RandomSample(params->stream, graph, cache, memorypool, count, device_id, op_id_, is_presc); 53 | cudaEventRecord(((params->event)), ((params->stream))); 54 | cudaCheckError(); 55 | } 56 | 
private: 57 | int op_id_; 58 | }; 59 | 60 | Operator* NewRandomSampleOP(int op_id){ 61 | return new RandomSampleOP(op_id); 62 | } 63 | 64 | class CacheLookupOP : public Operator { 65 | public: 66 | CacheLookupOP(int op_id){ 67 | op_id_ = op_id; 68 | } 69 | void run(OpParams* params) override { 70 | UnifiedCache* cache = (UnifiedCache*)(params->cache); 71 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 72 | int32_t device_id = params->device_id; 73 | 74 | FeatureCacheLookup(params->stream, cache, memorypool, op_id_, device_id); 75 | cudaEventRecord(((params->event)), ((params->stream))); 76 | cudaCheckError(); 77 | } 78 | private: 79 | int op_id_; 80 | }; 81 | 82 | Operator* NewCacheLookupOP(int op_id){ 83 | return new CacheLookupOP(op_id); 84 | } 85 | 86 | class SSDIOSubmitOP : public Operator { 87 | public: 88 | SSDIOSubmitOP(int op_id) { 89 | op_id_ = op_id; 90 | } 91 | void run(OpParams* params) override { 92 | FeatureStorage* feature = (FeatureStorage*)(params->feature); 93 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 94 | int32_t device_id = params->device_id; 95 | 96 | IOSubmit(params->stream, feature, memorypool, op_id_, device_id); 97 | cudaEventRecord(((params->event)), ((params->stream))); 98 | cudaCheckError(); 99 | } 100 | private: 101 | int op_id_; 102 | }; 103 | 104 | Operator* NewSSDIOSubmitOP(int op_id){ 105 | return new SSDIOSubmitOP(op_id); 106 | } 107 | 108 | class SSDIOCompleteOP : public Operator { 109 | public: 110 | SSDIOCompleteOP(int op_id){ 111 | op_id_ = op_id; 112 | } 113 | void run(OpParams* params) override { 114 | UnifiedCache* cache = (UnifiedCache*)(params->cache); 115 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 116 | int mode = memorypool->GetCurrentMode(); 117 | int32_t device_id = params->device_id; 118 | 119 | IOComplete(params->stream, cache, memorypool, device_id, mode); 120 | cudaEventRecord(((params->event)), ((params->stream))); 121 | cudaCheckError(); 122 | } 123 | private: 124 | int op_id_; 125 | }; 126 | 127 | Operator* NewSSDIOCompleteOP(int op_id){ 128 | return new SSDIOCompleteOP(op_id); 129 | } 130 | 131 | -------------------------------------------------------------------------------- /sampling_server/src/engine/operator.h: -------------------------------------------------------------------------------- 1 | #ifndef OPERATOR_H 2 | #define OPERATOR_H 3 | 4 | struct OpParams { 5 | int device_id; 6 | cudaStream_t stream; 7 | cudaEvent_t event; 8 | void* memorypool; 9 | void* cache; 10 | void* graph; 11 | void* feature; 12 | void* env; 13 | int neighbor_count; 14 | bool is_presc; 15 | bool in_memory; 16 | int hop_num; 17 | }; 18 | 19 | class Operator { 20 | public: 21 | virtual void run(OpParams* params) = 0; 22 | }; 23 | 24 | Operator* NewBatchGenerateOP(int op_id); 25 | Operator* NewRandomSampleOP(int op_id); 26 | Operator* NewCacheLookupOP(int op_id); 27 | Operator* NewSSDIOSubmitOP(int op_id); 28 | Operator* NewSSDIOCompleteOP(int op_id); 29 | 30 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/operator_impl.cuh: -------------------------------------------------------------------------------- 1 | #ifndef OPERATOR_IMPL_H 2 | #define OPERATOR_IMPL_H 3 | #include 4 | #include 5 | 6 | #include "memorypool.cuh" 7 | #include "graph_storage.cuh" 8 | #include "feature_storage.cuh" 9 | #include "cache.cuh" 10 | 11 | extern "C" 12 | void BatchGenerate( 13 | cudaStream_t strm_hdl, 14 | FeatureStorage* feature, 15 | UnifiedCache* cache, 
16 | MemoryPool* memorypool, 17 | int32_t batch_size, 18 | int32_t counter, 19 | int32_t part_id, 20 | int32_t dev_id, 21 | int32_t mode, 22 | bool is_presc, 23 | int32_t hop_num 24 | ); 25 | 26 | extern "C" 27 | void RandomSample( 28 | cudaStream_t strm_hdl, 29 | GraphStorage* graph, 30 | UnifiedCache* cache, 31 | MemoryPool* memorypool, 32 | int32_t count, 33 | int32_t dev_id, 34 | int32_t op_id, 35 | bool is_presc 36 | ); 37 | 38 | extern "C" 39 | void FeatureCacheLookup( 40 | cudaStream_t strm_hdl, 41 | UnifiedCache* cache, 42 | MemoryPool* memorypool, 43 | int32_t op_id, 44 | int32_t dev_id 45 | ); 46 | 47 | extern "C" 48 | void IOSubmit( 49 | cudaStream_t strm_hdl, 50 | FeatureStorage* feature, 51 | MemoryPool* memorypool, 52 | int32_t op_id, 53 | int32_t dev_id 54 | ); 55 | 56 | extern "C" 57 | void IOComplete( 58 | cudaStream_t strm_hdl, 59 | UnifiedCache* cache, 60 | MemoryPool* memorypool, 61 | int32_t dev_id, 62 | int32_t mode 63 | ); 64 | 65 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/server.h: -------------------------------------------------------------------------------- 1 | #ifndef SERVER_H 2 | #define SERVER_H 3 | #include 4 | 5 | struct RunnerParams { 6 | int device_id; 7 | std::vector fanout; 8 | void* cache; 9 | void* graph; 10 | void* feature; 11 | void* env; 12 | int global_batch_id; 13 | bool in_memory; 14 | }; 15 | 16 | class Server { 17 | public: 18 | virtual void Initialize(int global_shard_count, std::vector fanout, int in_memory_mode) = 0; 19 | virtual void PreSc(int cache_agg_mode) = 0; 20 | virtual void Run() = 0; 21 | virtual void Finalize() = 0; 22 | }; 23 | Server* NewGPUServer(); 24 | 25 | class Runner { 26 | public: 27 | virtual void Initialize(RunnerParams* params) = 0; 28 | virtual void InitializeFeaturesBuffer(RunnerParams* params) = 0; 29 | virtual void RunPreSc(RunnerParams* params) = 0; 30 | virtual void RunOnce(RunnerParams* params) = 0; 31 | virtual void Finalize(RunnerParams* params) = 0; 32 | }; 33 | Runner* NewGPURunner(); 34 | 35 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/server_imp.cuh: -------------------------------------------------------------------------------- 1 | 2 | extern "C" 3 | void* d_alloc_space(int64_t num_bytes) { 4 | void *ret; 5 | cudaMalloc(&ret, num_bytes); 6 | cudaCheckError(); 7 | return ret; 8 | } 9 | 10 | extern "C" 11 | void* d_alloc_space_managed(unsigned int num_bytes) { 12 | void *ret; 13 | cudaMallocManaged(&ret, num_bytes); 14 | cudaCheckError(); 15 | return ret; 16 | } 17 | 18 | extern "C" 19 | void d_copy_2_h(void* h_ptr, void* d_ptr, unsigned int num_bytes){ 20 | cudaMemcpy(h_ptr, d_ptr, num_bytes, cudaMemcpyDeviceToHost); 21 | cudaCheckError(); 22 | } 23 | 24 | 25 | extern "C" 26 | void SetGPUDevice(int32_t shard_id){ 27 | cudaSetDevice(shard_id); 28 | cudaCheckError(); 29 | } 30 | 31 | extern "C" 32 | int32_t GetGPUDevice(){ 33 | int32_t dev_id; 34 | cudaGetDevice(&dev_id); 35 | return dev_id; 36 | } 37 | 38 | extern "C" 39 | void d_free_space(void* d_ptr){ 40 | cudaFree(d_ptr); 41 | } 42 | 43 | 44 | extern "C" 45 | void* host_alloc_space(unsigned int num_bytes) { 46 | void* host_ptr; 47 | void* ret; 48 | cudaHostAlloc(&host_ptr, num_bytes, cudaHostAllocMapped); 49 | cudaHostGetDevicePointer(&ret, host_ptr, 0); 50 | cudaCheckError(); 51 | return ret; 52 | } -------------------------------------------------------------------------------- 
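A minimal usage sketch for the zero-copy helper `host_alloc_space` defined above (this example is not part of the repo; `fill_ids`, `n`, and the launch shape are illustrative). The helper pins a host buffer with `cudaHostAllocMapped` and returns its device-side alias, so a kernel can write host-visible memory without an explicit `cudaMemcpy`:

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

extern "C" void* host_alloc_space(unsigned int num_bytes);

// Illustrative kernel: writes land directly in the pinned host buffer.
__global__ void fill_ids(int32_t* buf, int32_t n) {
    int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] = i;
}

int main() {
    const int32_t n = 1 << 20;
    // Device pointer aliasing cudaHostAllocMapped host memory.
    int32_t* d_buf = static_cast<int32_t*>(host_alloc_space(n * sizeof(int32_t)));
    fill_ids<<<(n + 255) / 256, 256>>>(d_buf, n);
    cudaDeviceSynchronize();
    // Under unified virtual addressing the device alias is also valid on the host.
    printf("buf[42] = %d\n", d_buf[42]);
    return 0;
}
```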
/sampling_server/src/include/buildinfo.h: -------------------------------------------------------------------------------- 1 | #ifndef BUILD_INFO_H 2 | #define BUILD_INFO_H 3 | #include 4 | #include 5 | 6 | struct BuildInfo{ 7 | //device `` 8 | std::vector shard_to_partition; 9 | std::vector shard_to_device; 10 | int32_t partition_count; 11 | //training set 12 | std::vector training_set_num; 13 | std::vector> training_set_ids; 14 | std::vector> training_labels; 15 | //validation set 16 | std::vector validation_set_num; 17 | std::vector> validation_set_ids; 18 | std::vector> validation_labels; 19 | //testing set 20 | std::vector testing_set_num; 21 | std::vector> testing_set_ids; 22 | std::vector> testing_labels; 23 | //features 24 | int32_t total_num_nodes; 25 | int32_t float_feature_len; 26 | float* host_float_feature;//allocated by cudaHostAlloc 27 | 28 | //bam params 29 | uint32_t cudaDevice; 30 | uint64_t cudaDeviceId; 31 | const char* blockDevicePath; 32 | const char* controllerPath; 33 | uint64_t controllerId; 34 | uint32_t adapter; 35 | uint32_t segmentId; 36 | uint32_t nvmNamespace; 37 | bool doubleBuffered; 38 | size_t numReqs; 39 | size_t numPages; 40 | size_t startBlock; 41 | bool stats; 42 | const char* output; 43 | size_t numThreads; 44 | uint32_t domain; 45 | uint32_t bus; 46 | uint32_t devfn; 47 | uint32_t n_ctrls; 48 | size_t blkSize; 49 | size_t queueDepth; 50 | size_t numQueues; 51 | size_t pageSize; 52 | uint64_t numElems; 53 | bool random; 54 | uint64_t ssdtype; 55 | 56 | //csr 57 | // std::vector> csr_node_index; 58 | // std::vector> csr_dst_node_ids; 59 | int64_t* csr_node_index; 60 | int32_t* csr_dst_node_ids; 61 | // std::vector partition_index; 62 | // std::vector partition_offset; 63 | int64_t cache_edge_num; 64 | int64_t total_edge_num; 65 | //train 66 | int32_t epoch; 67 | int32_t raw_batch_size; 68 | 69 | //iostack 70 | int32_t num_ssd; 71 | int32_t num_queues_per_ssd; 72 | }; 73 | 74 | #endif -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap.h: -------------------------------------------------------------------------------- 1 | #ifndef HASHMAPH 2 | #define HASHMAPH 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | // #include 9 | // #include 10 | #include 11 | // #include 12 | #include 13 | #include 14 | 15 | #endif -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE bght_src 2 | "*.h" 3 | "*.hpp" 4 | "*.cuh" 5 | "*.cu") 6 | set(SOURCE_LIST ${bght_src}) 7 | target_sources(bght INTERFACE ${bght_src}) 8 | target_include_directories(bght INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}") -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/benchmark_helpers.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | namespace benchmark { 29 | 30 | // min_key and max_key are exclusive 31 | template 32 | void generate_uniform_unique_keys( 33 | std::vector& keys, 34 | std::size_t num_keys, 35 | key_type min_key = std::numeric_limits::min() + 1, 36 | key_type max_key = std::numeric_limits::max() - 1, 37 | unsigned seed = 1, 38 | bool cache = false) { 39 | keys.resize(num_keys); 40 | std::string dataset_dir = "dataset"; 41 | std::string dataset_name = std::to_string(num_keys) + "_" + std::to_string(seed); 42 | std::string dataset_path = dataset_dir + "/" + dataset_name; 43 | if (cache) { 44 | if (std::filesystem::exists(dataset_dir)) { 45 | if (std::filesystem::exists(dataset_path)) { 46 | std::cout << "Reading cached keys.." << std::endl; 47 | std::ifstream dataset(dataset_path, std::ios::binary); 48 | dataset.read((char*)keys.data(), sizeof(key_type) * num_keys); 49 | dataset.close(); 50 | return; 51 | } 52 | } else { 53 | std::filesystem::create_directory(dataset_dir); 54 | } 55 | } 56 | std::random_device rd; 57 | std::mt19937 rng(seed); 58 | std::uniform_int_distribution uni(min_key, max_key); 59 | std::unordered_set unique_keys; 60 | while (unique_keys.size() < num_keys) { 61 | unique_keys.insert(uni(rng)); 62 | } 63 | std::copy(unique_keys.cbegin(), unique_keys.cend(), keys.begin()); 64 | 65 | if (cache) { 66 | std::cout << "Caching.." 
<< std::endl; 67 | std::ofstream dataset(dataset_path, std::ios::binary); 68 | dataset.write((char*)keys.data(), sizeof(key_type) * num_keys); 69 | dataset.close(); 70 | } 71 | } 72 | // 73 | // template 74 | // uint64_t validate(const std::vector& h_keys, 75 | // const std::vector& h_find_keys, 76 | // const thrust::device_vector& d_results, 77 | // const uint32_t& num_keys, 78 | // const value_type& sentinel_value, 79 | // function to_value, 80 | // float exist_ratio = 1.0f) { 81 | // uint64_t num_errors = 0; 82 | // uint64_t max_errors = 10; 83 | // using pair_type = bght::pair_type; 84 | // auto h_results = thrust::host_vector(d_results); 85 | // std::unordered_set cpu_ref_set; 86 | // if (exist_ratio != 1.0f) { 87 | // cpu_ref_set.insert(h_keys.begin(), h_keys.begin() + num_keys); 88 | // } 89 | // for (size_t i = 0; i < num_keys; i++) { 90 | // key_type query_key = h_find_keys[i]; 91 | // value_type query_result = h_results[i]; 92 | // value_type expected_result = to_value(query_key); 93 | // if (exist_ratio != 1.0f) { 94 | // auto expected_result_ptr = cpu_ref_set.find(query_key); 95 | // if (expected_result_ptr == cpu_ref_set.end()) { 96 | // expected_result = sentinel_value; 97 | // } 98 | // } 99 | // 100 | // if (query_result != expected_result) { 101 | // std::string message = std::string("query_key = ") + std::to_string(query_key) + 102 | // std::string(", expected: ") + 103 | // std::to_string(expected_result) + std::string(", found: ") + 104 | // std::to_string(query_result); 105 | // std::cout << message << std::endl; 106 | // num_errors++; 107 | // if (num_errors == max_errors) 108 | // break; 109 | // } 110 | // } 111 | // return num_errors; 112 | //} 113 | 114 | template 115 | void prep_experiment_find_with_exist_ratio(float exist_ratio, 116 | std::size_t num_keys, 117 | const std::vector& keys, 118 | std::vector& find_keys, 119 | key_type* d_find_keys) { 120 | // Choose the keys over which we will search based on the 121 | // exist_ratio. Recall that keys.size() == 2 * num_keys. 122 | assert(num_keys * 2 == keys.size()); 123 | unsigned int end_index = num_keys * (-exist_ratio + 2); 124 | unsigned int start_index = end_index - num_keys; 125 | 126 | static constexpr uint32_t EMPTY_VALUE = 0xFFFFFFFF; 127 | 128 | // Need to copy our range [start_index, end_index) from keys 129 | // into find_keys. 130 | std::fill(find_keys.begin(), find_keys.end(), EMPTY_VALUE); 131 | std::copy(keys.begin() + start_index, keys.begin() + end_index, find_keys.begin()); 132 | cuda_try(cudaMemcpy(d_find_keys, 133 | find_keys.data(), 134 | sizeof(key_type) * find_keys.size(), 135 | cudaMemcpyHostToDevice)); 136 | } 137 | 138 | template 139 | void prep_experiment_find_with_exist_ratio(float exist_ratio, 140 | std::size_t num_keys, 141 | const thrust::device_vector& keys, 142 | thrust::device_vector& find_keys) { 143 | // Choose the keys over which we will search based on the 144 | // exist_ratio. Recall that keys.size() == 2 * num_keys. 145 | assert(num_keys * 2 == keys.size()); 146 | unsigned int end_index = num_keys * (-exist_ratio + 2); 147 | unsigned int start_index = end_index - num_keys; 148 | 149 | static constexpr uint32_t EMPTY_VALUE = 0xFFFFFFFF; 150 | 151 | // Need to copy our range [start_index, end_index) from keys 152 | // into find_keys. 
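// Concretely, and assuming (as the commented-out validate() above does)
// that the first num_keys entries of `keys` are the ones inserted: with
// num_keys = 100 and exist_ratio = 0.8, end_index = 100 * (2 - 0.8) = 120
// and start_index = 20, so the queried window [20, 120) overlaps the
// inserted half [0, 100) in exactly exist_ratio * num_keys = 80 keys.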
153 | thrust::fill(thrust::device, find_keys.begin(), find_keys.end(), EMPTY_VALUE); 154 | thrust::copy(thrust::device, 155 | keys.begin() + start_index, 156 | keys.begin() + end_index, 157 | find_keys.begin()); 158 | } 159 | 160 | } // namespace benchmark 161 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/cht.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | namespace bght { 28 | 29 | /** 30 | * @brief CHT CHT (cuckoo hash table) is an associative static GPU hash table 31 | * that contains key-value pairs with unique keys. The hash table is an open addressing 32 | * hash table based on the cuckoo hashing probing scheme (bucket size of one and using 33 | * four hash functions). 34 | * 35 | * @tparam Key Type for the hash map key 36 | * @tparam T Type for the mapped value 37 | * @tparam Hash Unary function object class that defines the hash function. The function 38 | * must have an `initialize_hf` specialization to initialize the hash function using a 39 | * random number generator 40 | * @tparam KeyEqual Binary function object class that compares two keys 41 | * @tparam Allocator The allocator to use for allocating GPU device memory 42 | */ 43 | template , 46 | class KeyEqual = bght::equal_to, 47 | cuda::thread_scope Scope = cuda::thread_scope_device, 48 | class Allocator = bght::cuda_allocator> 49 | struct cht { 50 | using value_type = pair; 51 | using key_type = Key; 52 | using mapped_type = T; 53 | using atomic_pair_type = cuda::atomic; 54 | using allocator_type = Allocator; 55 | using hasher = Hash; 56 | using atomic_pair_allocator_type = 57 | typename std::allocator_traits::rebind_alloc; 58 | using pool_allocator_type = 59 | typename std::allocator_traits::rebind_alloc; 60 | using key_equal = KeyEqual; 61 | 62 | /** 63 | * @brief Constructs the hash table with the specified capacity and uses the specified 64 | * sentinel key and value to define a sentinel pair. 
65 | * 66 | * @param capacity The number of slots to use in the hash table 67 | * @param sentinel_key A reserved sentinel key that defines an empty key 68 | * @param sentinel_value A reserved sentinel value that defines an empty value 69 | * @param allocator The allocator to use for allocating GPU device memory 70 | */ 71 | cht(std::size_t capacity, 72 | Key sentinel_key, 73 | T sentinel_value, 74 | Allocator const& allocator = Allocator{}); 75 | 76 | /** 77 | * @brief A shallow-copy constructor 78 | */ 79 | cht(const cht& other); 80 | /** 81 | * @brief Move constructor is currently deleted 82 | */ 83 | cht(cht&&) = delete; 84 | /** 85 | * @brief The assignment operator is currently deleted 86 | */ 87 | cht& operator=(const cht&) = delete; 88 | /** 89 | * @brief The move assignment operator is currently deleted 90 | */ 91 | cht& operator=(cht&&) = delete; 92 | /** 93 | * @brief Destructor that destroys the hash map and deallocate memory if no copies exist 94 | */ 95 | ~cht(); 96 | /** 97 | * @brief Clears the hash map and resets all slots 98 | */ 99 | void clear(); 100 | 101 | /** 102 | * @brief Host-side API for inserting all pairs defined by the input argument iterators. 103 | * All keys in the range must be unique and must not exist in the hash table. 104 | * @tparam InputIt Device-side iterator that can be converted to `value_type`. 105 | * @param first An iterator defining the beginning of the input pairs to insert 106 | * @param last An iterator defining the end of the input pairs to insert 107 | * @param stream A CUDA stream where the insertion operation will take place 108 | * @return A boolean indicating success (true) or failure (false) of the insertion 109 | * operation. 110 | */ 111 | template 112 | bool insert(InputIt first, InputIt last, cudaStream_t stream = 0); 113 | 114 | /** 115 | * @brief Host-side API for finding all keys defined by the input argument iterators. 116 | * @tparam InputIt Device-side iterator that can be converted to `key_type` 117 | * @tparam OutputIt Device-side iterator that can be converted to `mapped_type` 118 | * @param first An iterator defining the beginning of the input keys to find 119 | * @param last An iterator defining the end of the input keys to find 120 | * @param output_begin An iterator defining the beginning of the output buffer to store 121 | * the results into. The size of the buffer must match the number of queries defined by 122 | * the input iterators. 123 | * @param stream A CUDA stream where the insertion operation will take place 124 | */ 125 | template 126 | void find(InputIt first, InputIt last, OutputIt output_begin, cudaStream_t stream = 0); 127 | 128 | /** 129 | * @brief Device-side cooperative insertion API that inserts a single pair into the hash 130 | * map. This function is called by a single thread. 131 | * @param pair A key-value pair to insert into the hash map. 132 | * @return A boolean indicating success (true) or failure (false) of the insertion 133 | * operation. 134 | */ 135 | __device__ bool insert(value_type const& pair); 136 | 137 | /** 138 | * @brief Device-side cooperative find API that finds a single pair into the hash 139 | * map. 140 | * @param key A key to find in the hash map. 
The key must be the same 141 | * for all threads in the cooperative group tile 142 | * @return The value of the key if it exists in the map or the `sentinel_value` if the 143 | * key does not exist in the hash map 144 | */ 145 | __device__ mapped_type find(key_type const& key); 146 | 147 | /** 148 | * @brief Host-side API to randomize the hash functions used for the probing scheme. 149 | * This can be used when the hash table construction fails. The hash table must be 150 | * cleared after a call to this function. 151 | * @tparam RNG A pseudo-random number generator 152 | * @param rng An instantiation of the pseudo-random number generator 153 | */ 154 | template 155 | void randomize_hash_functions(RNG& rng); 156 | 157 | private: 158 | __device__ void set_build_success(const bool& success) { *d_build_success_ = success; } 159 | 160 | template 161 | friend __global__ void detail::kernels::insert_kernel(InputIt, InputIt, HashMap); 162 | 163 | template 164 | friend __global__ void detail::kernels::find_kernel(InputIt, 165 | InputIt, 166 | OutputIt, 167 | HashMap); 168 | 169 | std::size_t capacity_; 170 | key_type sentinel_key_{}; 171 | mapped_type sentinel_value_{}; 172 | allocator_type allocator_; 173 | atomic_pair_allocator_type atomic_pairs_allocator_; 174 | pool_allocator_type pool_allocator_; 175 | 176 | atomic_pair_type* d_table_{}; 177 | std::shared_ptr table_; 178 | 179 | bool* d_build_success_; 180 | std::shared_ptr build_success_; 181 | 182 | uint32_t max_cuckoo_chains_; 183 | 184 | Hash hf0_; 185 | Hash hf1_; 186 | Hash hf2_; 187 | Hash hf3_; 188 | 189 | std::size_t num_buckets_; 190 | }; 191 | 192 | } // namespace bght 193 | 194 | #include 195 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/cmd.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | std::string str_tolower(const std::string_view s) { 26 | std::string output(s.length(), ' '); 27 | std::transform(s.begin(), s.end(), output.begin(), [](unsigned char c) { 28 | return std::tolower(c); 29 | }); 30 | return output; 31 | } 32 | 33 | // Finds an argument value 34 | // auto arguments = std::vector(argv, argv + argc); 35 | // Example: 36 | // auto k = get_arg_value(arguments, "-flag") 37 | // auto arguments = std::vector(argv, argv + argc); 38 | template 39 | std::optional get_arg_value(const std::vector& arguments, 40 | const char* flag) { 41 | uint32_t first_argument = 1; 42 | for (uint32_t i = first_argument; i < arguments.size(); i++) { 43 | std::string_view argument = std::string_view(arguments[i]); 44 | auto key_start = argument.find_first_not_of("-"); 45 | auto value_start = argument.find("="); 46 | 47 | bool failed = argument.length() == 0; // there is an argument 48 | failed |= key_start == std::string::npos; // it has a - 49 | failed |= value_start == std::string::npos; // it has an = 50 | failed |= key_start > 2; // - or -- at beginning 51 | failed |= (value_start - key_start) == 0; // there is a key 52 | failed |= (argument.length() - value_start) == 1; // = is not last 53 | 54 | if (failed) { 55 | std::cout << "Invalid argument: " << argument << " ignored.\n"; 56 | std::cout << "Use: -flag=value " << std::endl; 57 | std::terminate(); 58 | } 59 | 60 | std::string_view argument_name = argument.substr(key_start, value_start - key_start); 61 | value_start++; // ignore the = 62 | std::string_view argument_value = 63 | argument.substr(value_start, argument.length() - key_start); 64 | 65 | if (argument_name == std::string_view(flag)) { 66 | if constexpr (std::is_same::value) { 67 | return static_cast(std::strtof(argument_value.data(), nullptr)); 68 | } else if constexpr (std::is_same::value) { 69 | return static_cast(std::strtod(argument_value.data(), nullptr)); 70 | } else if constexpr (std::is_same::value) { 71 | return static_cast(std::strtol(argument_value.data(), nullptr, 10)); 72 | } else if constexpr (std::is_same::value) { 73 | return static_cast(std::strtoll(argument_value.data(), nullptr, 10)); 74 | } else if constexpr (std::is_same::value) { 75 | return static_cast(std::strtoul(argument_value.data(), nullptr, 10)); 76 | } else if constexpr (std::is_same::value) { 77 | return static_cast(std::strtoull(argument_value.data(), nullptr, 10)); 78 | } else if constexpr (std::is_same::value) { 79 | return std::string(argument_value); 80 | } else if constexpr (std::is_same::value) { 81 | return str_tolower(argument_value) == "true"; 82 | } else { 83 | std::cout << "Unknown type" << std::endl; 84 | std::terminate(); 85 | } 86 | } 87 | } 88 | return {}; 89 | } 90 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/allocator.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | namespace bght { 21 | template 22 | struct cuda_deleter { 23 | void operator()(T* p) { cuda_try(cudaFree(p)); } 24 | }; 25 | 26 | template 27 | struct cuda_allocator { 28 | typedef std::size_t size_type; 29 | typedef std::ptrdiff_t difference_type; 30 | 31 | typedef T value_type; 32 | typedef T* pointer; 33 | typedef const T* const_pointer; 34 | typedef T& reference; 35 | typedef const T& const_reference; 36 | 37 | template 38 | struct rebind { 39 | typedef cuda_allocator other; 40 | }; 41 | cuda_allocator() = default; 42 | template 43 | constexpr cuda_allocator(const cuda_allocator&) noexcept {} 44 | T* allocate(std::size_t n) { 45 | void* p = nullptr; 46 | cuda_try(cudaMalloc(&p, n * sizeof(T))); 47 | return static_cast(p); 48 | } 49 | void deallocate(T* p, std::size_t n) noexcept { cuda_try(cudaFree(p)); } 50 | }; 51 | } // namespace bght 52 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/benchmark_metrics.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #ifdef COUNT_PROBES 20 | __device__ __managed__ uint32_t global_probes_count = 0; 21 | #define INCREMENT_PROBES_IN_TILE \ 22 | if (tile.thread_rank() == 0) \ 23 | atomicAdd(&global_probes_count, 1); 24 | #define INCREMENT_PROBES atomicAdd(&global_probes_count, 1); 25 | namespace bght { 26 | // uint32_t get_num_probes() { 27 | // cudaDeviceSynchronize(); 28 | // auto count = global_probes_count; 29 | // global_probes_count = 0; 30 | // cudaDeviceSynchronize(); 31 | // return count; 32 | // } 33 | } // namespace bght 34 | #else 35 | #define INCREMENT_PROBES_IN_TILE 36 | #define INCREMENT_PROBES 37 | namespace bght { 38 | // uint32_t get_num_probes() { 39 | // return 0; 40 | // } 41 | } // namespace bght 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/bucket.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | 21 | namespace bght { 22 | namespace detail { 23 | template 24 | struct bucket { 25 | bucket() = delete; 26 | DEVICE_QUALIFIER 27 | bucket(atomic_pair_type* ptr, const tile_type& tile) : ptr_(ptr), tile_(tile) {} 28 | 29 | DEVICE_QUALIFIER 30 | bucket(const bucket& other) : lane_pair_(other.lane_pair_), ptr_(other.ptr_) {} 31 | 32 | DEVICE_QUALIFIER 33 | void load(cuda::memory_order order = cuda::memory_order_seq_cst) { 34 | lane_pair_ = ptr_[tile_.thread_rank()].load(order); 35 | } 36 | DEVICE_QUALIFIER 37 | int compute_load(const pair_type& sentinel) { 38 | auto load_bitmap = tile_.ballot(lane_pair_.first != sentinel.first); 39 | int load = __popc(load_bitmap); 40 | return load; 41 | } 42 | // returns -1 if not found 43 | template 44 | DEVICE_QUALIFIER int find_key_location(const typename pair_type::first_type& key, 45 | const KeyEqual key_equal) { 46 | bool key_exist = key_equal(key, lane_pair_.first); 47 | auto key_exist_bmap = tile_.ballot(key_exist); 48 | int key_lane = __ffs(key_exist_bmap); 49 | return key_lane - 1; 50 | } 51 | DEVICE_QUALIFIER 52 | typename pair_type::second_type get_value_from_lane(int location) { 53 | return tile_.shfl(lane_pair_.second, location); 54 | } 55 | 56 | DEVICE_QUALIFIER 57 | bool weak_cas_at_location(const pair_type& pair, 58 | const int location, 59 | const pair_type& sentinel, 60 | cuda::memory_order success = cuda::memory_order_seq_cst, 61 | cuda::memory_order failure = cuda::memory_order_seq_cst) { 62 | pair_type expected = sentinel; 63 | pair_type desired = pair; 64 | bool cas_success = 65 | ptr_[location].compare_exchange_weak(expected, desired, success, failure); 66 | return cas_success; 67 | } 68 | 69 | DEVICE_QUALIFIER 70 | bool strong_cas_at_location(const pair_type& pair, 71 | const int location, 72 | const pair_type& sentinel, 73 | cuda::memory_order success = cuda::memory_order_seq_cst, 74 | cuda::memory_order failure = cuda::memory_order_seq_cst) { 75 | pair_type expected = sentinel; 76 | pair_type desired = pair; 77 | bool cas_success = 78 | ptr_[location].compare_exchange_strong(expected, desired, success, failure); 79 | return cas_success; 80 | } 81 | 82 | DEVICE_QUALIFIER 83 | pair_type exch_at_location(const pair_type& pair, 84 | const int location, 85 | cuda::memory_order order = cuda::memory_order_seq_cst) { 86 | auto old = ptr_[location].exchange(pair, order); 87 | return old; 88 | } 89 | 90 | private: 91 | pair_type lane_pair_; 92 | atomic_pair_type* ptr_; 93 | tile_type tile_; 94 | }; 95 | } // namespace detail 96 | } // namespace bght 97 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/cht_impl.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | namespace bght { 29 | template 35 | cht::cht(std::size_t capacity, 36 | Key empty_key_sentinel, 37 | T empty_value_sentinel, 38 | Allocator const& allocator) 39 | : capacity_{std::max(capacity, std::size_t{1})} 40 | , sentinel_key_{empty_key_sentinel} 41 | , sentinel_value_{empty_value_sentinel} 42 | , allocator_{allocator} 43 | , atomic_pairs_allocator_{allocator} 44 | , pool_allocator_{allocator} { 45 | num_buckets_ = capacity_; 46 | d_table_ = std::allocator_traits::allocate( 47 | atomic_pairs_allocator_, capacity_); 48 | table_ = 49 | std::shared_ptr(d_table_, bght::cuda_deleter()); 50 | 51 | d_build_success_ = 52 | std::allocator_traits::allocate(pool_allocator_, 1); 53 | build_success_ = std::shared_ptr(d_build_success_, bght::cuda_deleter()); 54 | 55 | value_type empty_pair{sentinel_key_, sentinel_value_}; 56 | 57 | thrust::fill(thrust::device, d_table_, d_table_ + capacity_, empty_pair); 58 | 59 | // maximum number of cuckoo chains 60 | double lg_input_size = (float)(log((double)capacity) / log(2.0)); 61 | const unsigned max_iter_const = 7; 62 | max_cuckoo_chains_ = static_cast(max_iter_const * lg_input_size); 63 | 64 | std::mt19937 rng(2); 65 | hf0_ = initialize_hf(rng); 66 | hf1_ = initialize_hf(rng); 67 | hf2_ = initialize_hf(rng); 68 | hf3_ = initialize_hf(rng); 69 | 70 | bool success = true; 71 | cuda_try(cudaMemcpy(d_build_success_, &success, sizeof(bool), cudaMemcpyHostToDevice)); 72 | } 73 | 74 | template 80 | cht::cht(const cht& other) 81 | : capacity_(other.capacity_) 82 | , sentinel_key_(other.sentinel_key_) 83 | , sentinel_value_(other.sentinel_value_) 84 | , allocator_(other.allocator_) 85 | , atomic_pairs_allocator_(other.atomic_pairs_allocator_) 86 | , pool_allocator_(other.pool_allocator_) 87 | , d_table_(other.d_table_) 88 | , table_(other.table_) 89 | , d_build_success_(other.d_build_success_) 90 | , build_success_(other.build_success_) 91 | , max_cuckoo_chains_(other.max_cuckoo_chains_) 92 | , hf0_(other.hf0_) 93 | , hf1_(other.hf1_) 94 | , hf2_(other.hf2_) 95 | , hf3_(other.hf3_) 96 | , num_buckets_(other.num_buckets_) {} 97 | 98 | template 104 | cht::~cht() {} 105 | 106 | template 112 | void cht::clear() { 113 | value_type empty_pair{sentinel_key_, sentinel_value_}; 114 | thrust::fill(thrust::device, d_table_, d_table_ + capacity_, empty_pair); 115 | bool success = true; 116 | cuda_try(cudaMemcpy(d_build_success_, &success, sizeof(bool), cudaMemcpyHostToDevice)); 117 | } 118 | 119 | template 125 | template 126 | bool cht::insert(InputIt first, 127 | InputIt last, 128 | cudaStream_t stream) { 129 | const auto num_keys = std::distance(first, last); 130 | 131 | const uint32_t block_size = 128; 132 | const uint32_t num_blocks = 133 | static_cast((num_keys + block_size - 1) / block_size); 134 | detail::kernels::insert_kernel<<>>( 135 | first, last, *this); 136 | // cuda_try(cudaPeekAtLastError()); 137 | 138 | bool success; 139 | cuda_try(cudaMemcpyAsync( 140 | 
&success, d_build_success_, sizeof(bool), cudaMemcpyDeviceToHost, stream)); 141 | return success; 142 | } 143 | 144 | template 150 | template 151 | void cht::find(InputIt first, 152 | InputIt last, 153 | OutputIt output_begin, 154 | cudaStream_t stream) { 155 | const auto num_keys = std::distance(first, last); 156 | 157 | const uint32_t block_size = 128; 158 | const uint32_t num_blocks = 159 | static_cast((num_keys + block_size - 1) / block_size); 160 | 161 | detail::kernels::find_kernel<<>>( 162 | first, last, output_begin, *this); 163 | // cuda_try(cudaPeekAtLastError()); 164 | } 165 | 166 | template 172 | __device__ bool bght::cht::insert( 173 | value_type const& pair) { 174 | auto bucket_id = hf0_(pair.first) % num_buckets_; 175 | 176 | uint32_t cuckoo_counter = 0; 177 | value_type sentinel_pair{sentinel_key_, sentinel_value_}; 178 | value_type insertion_pair = pair; 179 | do { 180 | auto old_pair = 181 | d_table_[bucket_id].exchange(insertion_pair, cuda::memory_order_relaxed); 182 | INCREMENT_PROBES 183 | if (old_pair.first == sentinel_key_) { 184 | return true; 185 | } else { 186 | auto bucket0 = hf0_(old_pair.first) % num_buckets_; 187 | auto bucket1 = hf1_(old_pair.first) % num_buckets_; 188 | auto bucket2 = hf2_(old_pair.first) % num_buckets_; 189 | auto bucket3 = hf3_(old_pair.first) % num_buckets_; 190 | 191 | auto new_bucket_id = bucket0; 192 | new_bucket_id = bucket_id == bucket2 ? bucket3 : new_bucket_id; 193 | new_bucket_id = bucket_id == bucket1 ? bucket2 : new_bucket_id; 194 | new_bucket_id = bucket_id == bucket0 ? bucket1 : new_bucket_id; 195 | 196 | bucket_id = new_bucket_id; 197 | insertion_pair = old_pair; 198 | } 199 | cuckoo_counter++; 200 | } while (cuckoo_counter < max_cuckoo_chains_); 201 | return false; 202 | } 203 | 204 | template 210 | __device__ bght::cht::mapped_type 211 | bght::cht::find(key_type const& key) { 212 | const int num_hfs = 4; 213 | auto bucket_id = hf0_(key) % num_buckets_; 214 | for (int hf = 0; hf < num_hfs; hf++) { 215 | auto pair = d_table_[bucket_id].load(cuda::memory_order_relaxed); 216 | INCREMENT_PROBES 217 | if (pair.first == key) { 218 | return pair.second; 219 | } else if (pair.first == sentinel_key_) { 220 | return sentinel_value_; 221 | } else { 222 | if (hf == 0) { 223 | bucket_id = hf1_(key) % num_buckets_; 224 | } else if (hf == 1) { 225 | bucket_id = hf2_(key) % num_buckets_; 226 | } else { 227 | bucket_id = hf3_(key) % num_buckets_; 228 | } 229 | } 230 | } 231 | 232 | return sentinel_value_; 233 | } 234 | 235 | template 241 | template 242 | void bght::cht::randomize_hash_functions( 243 | RNG& rng) { 244 | hf0_ = initialize_hf(rng); 245 | hf1_ = initialize_hf(rng); 246 | hf2_ = initialize_hf(rng); 247 | hf3_ = initialize_hf(rng); 248 | } 249 | } // namespace bght 250 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/cuda_helpers.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | namespace bght { 19 | #define _device_ __device__ __forceinline__ 20 | #define _host_device_ __device__ __host__ __forceinline__ 21 | #define _kernel_ __global__ 22 | #define DEVICE_QUALIFIER __device__ inline 23 | namespace detail { 24 | #define cuda_try(call) \ 25 | do { \ 26 | cudaError_t err = call; \ 27 | if (err != cudaSuccess) { \ 28 | printf("CUDA error at %s %d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ 29 | std::terminate(); \ 30 | } \ 31 | } while (0) 32 | 33 | _device_ void cuda_assert(bool expression_result, char* message = nullptr) { 34 | if (!expression_result) { 35 | if (message && (threadIdx.x & 0x1f == 0)) { 36 | printf("assert failed: %s", message); 37 | } 38 | //__trap(); 39 | asm("trap;"); 40 | } 41 | } 42 | } // namespace detail 43 | 44 | // void set_device(int device_id) { 45 | // int device_count; 46 | // cudaGetDeviceCount(&device_count); 47 | // cudaDeviceProp devProp; 48 | // if (device_id < device_count) { 49 | // cudaSetDevice(device_id); 50 | // cudaGetDeviceProperties(&devProp, device_id); 51 | // std::cout << "Device[" << device_id << "]: " << devProp.name << std::endl; 52 | // } else { 53 | // std::cout << "No capable CUDA device found." << std::endl; 54 | // std::terminate(); 55 | // } 56 | // } 57 | } // namespace bght 58 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/hash_functions.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | namespace bght { 19 | template 20 | struct universal_hash { 21 | using key_type = Key; 22 | using result_type = Key; 23 | __host__ __device__ constexpr universal_hash(uint32_t hash_x, uint32_t hash_y) 24 | : hash_x_(hash_x), hash_y_(hash_y) {} 25 | 26 | constexpr result_type __host__ __device__ operator()(const key_type& key) const { 27 | return (((hash_x_ ^ key) + hash_y_) % prime_divisor); 28 | } 29 | 30 | universal_hash(const universal_hash&) = default; 31 | universal_hash() = default; 32 | universal_hash(universal_hash&&) = default; 33 | universal_hash& operator=(universal_hash const&) = default; 34 | universal_hash& operator=(universal_hash&&) = default; 35 | ~universal_hash() = default; 36 | 37 | static constexpr uint32_t prime_divisor = 4294967291u; 38 | 39 | private: 40 | uint32_t hash_x_; 41 | uint32_t hash_y_; 42 | }; 43 | 44 | template 45 | Hash initialize_hf(RNG& rng) { 46 | uint32_t x = rng() % Hash::prime_divisor; 47 | if (x < 1u) { 48 | x = 1; 49 | } 50 | uint32_t y = rng() % Hash::prime_divisor; 51 | return Hash(x, y); 52 | } 53 | } // namespace bght 54 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/kernels.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | 22 | namespace bght { 23 | namespace detail { 24 | namespace kernels { 25 | template 26 | __global__ void tiled_insert_kernel(InputIt first, InputIt last, HashMap map) { 27 | // construct the tile 28 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 29 | auto block = cooperative_groups::this_thread_block(); 30 | auto tile = cooperative_groups::tiled_partition(block); 31 | 32 | auto count = last - first; 33 | if ((thread_id - tile.thread_rank()) >= count) { 34 | return; 35 | } 36 | 37 | bool do_op = false; 38 | typename HashMap::value_type insertion_pair{}; 39 | 40 | // load the input 41 | if (thread_id < count) { 42 | insertion_pair = first[thread_id]; 43 | do_op = true; 44 | } 45 | 46 | bool success = true; 47 | // Do the insertion 48 | auto work_queue = tile.ballot(do_op); 49 | while (work_queue) { 50 | auto cur_rank = __ffs(work_queue) - 1; 51 | auto cur_pair = tile.shfl(insertion_pair, cur_rank); 52 | bool insertion_success = map.insert(cur_pair, tile); 53 | 54 | if (tile.thread_rank() == cur_rank) { 55 | do_op = false; 56 | success = insertion_success; 57 | } 58 | work_queue = tile.ballot(do_op); 59 | } 60 | 61 | if (!tile.all(success)) { 62 | *map.d_build_success_ = false; 63 | } 64 | } 65 | 66 | template 67 | __global__ void tiled_find_kernel(InputIt first, 68 | InputIt last, 69 | OutputIt output_begin, 70 | HashMap map) { 71 | // construct the tile 72 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 73 | auto block = cooperative_groups::this_thread_block(); 74 | auto tile = cooperative_groups::tiled_partition(block); 75 | 76 | auto count = last - first; 77 | if ((thread_id - tile.thread_rank()) >= count) { 78 | return; 79 | } 80 | 81 | bool do_op = false; 82 | typename HashMap::key_type find_key; 83 | typename HashMap::mapped_type result; 84 | 85 | // load the input 86 | if (thread_id < count) { 87 | find_key = first[thread_id]; 88 | do_op = true; 89 | } 90 | 91 | // Do the insertion 92 | auto work_queue = tile.ballot(do_op); 93 | while (work_queue) { 94 | auto cur_rank = __ffs(work_queue) - 1; 95 | auto cur_key = tile.shfl(find_key, cur_rank); 96 | 97 | typename HashMap::mapped_type find_result = map.find(cur_key, tile); 98 | 99 | if (tile.thread_rank() == cur_rank) { 100 | result = find_result; 101 | do_op = false; 102 | } 103 | work_queue = tile.ballot(do_op); 104 | } 105 | 106 | if (thread_id < count) { 107 | output_begin[thread_id] = result; 108 | } 109 | } 110 | 111 | template 112 | __global__ void insert_kernel(InputIt first, InputIt last, HashMap map) { 113 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 114 | auto count = last - first; 115 | 116 | if (thread_id < count) { 117 | auto insertion_pair = first[thread_id]; 118 | bool success = map.insert(insertion_pair); 119 | if (!success) { 120 | map.set_build_success(false); 121 | } 122 | } 123 | } 124 | 125 | template 126 | __global__ void find_kernel(InputIt first, 127 | InputIt last, 128 | OutputIt output_begin, 129 | HashMap map) { 130 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 131 | auto count = last - first; 132 | 133 | if (thread_id < count) { 134 | auto find_key = first[thread_id]; 135 | auto result = map.find(find_key); 136 | output_begin[thread_id] = result; 137 | } 138 | } 139 | 140 | template 141 | __global__ void count_kernel(const InputT count_key, std::size_t* count, HashMap map) { 142 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 143 | typedef cub::BlockReduce BlockReduce; 144 | 
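// Count pattern: each thread below loads one table slot and compares it
// against count_key; BlockReduce sums the per-thread matches inside the
// block, and only thread 0 publishes the block's partial sum with a
// single atomicAdd, rather than issuing one atomic per matching thread.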
__shared__ typename BlockReduce::TempStorage temp_storage; 145 | 146 | std::size_t match = 0; 147 | if (thread_id < map.capacity_) { 148 | const auto key = map.d_table_[thread_id].load(cuda::memory_order_relaxed).first; 149 | match = (key == count_key); 150 | } 151 | std::size_t block_num_matches = BlockReduce(temp_storage).Sum(match); 152 | if (threadIdx.x == 0) { 153 | auto sum = atomicAdd((unsigned long long int*)count, 154 | (unsigned long long int)block_num_matches); 155 | } 156 | } 157 | 158 | } // namespace kernels 159 | } // namespace detail 160 | } // namespace bght 161 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/pair.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | 21 | namespace bght { 22 | template () != 0> 23 | struct alignas(detail::pair_alignment()) padded_pair { 24 | using first_type = T1; 25 | using second_type = T2; 26 | T1 first; 27 | T2 second; 28 | padded_pair() = default; 29 | ~padded_pair() = default; 30 | padded_pair(padded_pair const&) = default; 31 | padded_pair(padded_pair&&) = default; 32 | padded_pair& operator=(padded_pair const&) = default; 33 | padded_pair& operator=(padded_pair&&) = default; 34 | 35 | __host__ __device__ inline bool operator==(const padded_pair& rhs) { 36 | return (this->first == rhs.first) && (this->second == rhs.second); 37 | } 38 | __host__ __device__ inline bool operator!=(const padded_pair& rhs) { 39 | return !(*this == rhs); 40 | } 41 | 42 | __host__ __device__ constexpr padded_pair(T1 const& t, T2 const& u) 43 | : first{t}, second{u} {} 44 | }; 45 | 46 | template 47 | struct alignas(detail::pair_alignment()) padded_pair { 48 | using first_type = T1; 49 | using second_type = T2; 50 | T1 first; 51 | T2 second; 52 | 53 | padded_pair() = default; 54 | ~padded_pair() = default; 55 | padded_pair(padded_pair const&) = default; 56 | padded_pair(padded_pair&&) = default; 57 | padded_pair& operator=(padded_pair const&) = default; 58 | padded_pair& operator=(padded_pair&&) = default; 59 | 60 | __host__ __device__ inline bool operator==(const padded_pair& rhs) { 61 | return (this->first == rhs.first) && (this->second == rhs.second); 62 | } 63 | __host__ __device__ inline bool operator!=(const padded_pair& rhs) { 64 | return !(*this == rhs); 65 | } 66 | 67 | __host__ __device__ constexpr padded_pair(T1 const& t, T2 const& u) 68 | : first{t}, second{u} {} 69 | 70 | private: 71 | char padding[detail::padding_size()] = {0}; 72 | }; 73 | 74 | template 75 | using pair = padded_pair; 76 | 77 | template 78 | struct equal_to { 79 | constexpr bool operator()(const T& lhs, const T& rhs) const { return lhs == rhs; } 80 | }; 81 | } // namespace bght 82 | 
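The padding machinery above exists so that every pair occupies a power-of-two size and alignment, which is what lets the tables store slots as `cuda::atomic` pairs and load or CAS a whole pair in one instruction. Two illustrative compile-time checks, a sketch assuming `bght::pair` is visible via pair.cuh (not from the repo; the arithmetic follows from next_alignment/padding_size in pair_detail.hpp below):
```
#include <cstdint>
// 4 + 4 bytes already sit on an 8-byte boundary: no padding is added.
static_assert(sizeof(bght::pair<uint32_t, uint32_t>) == 8, "packed");
static_assert(alignof(bght::pair<uint32_t, uint32_t>) == 8, "aligned");
// 4 + 8 bytes pad from 12 up to the next alignment boundary, 16.
static_assert(sizeof(bght::pair<uint32_t, uint64_t>) == 16, "padded");
static_assert(alignof(bght::pair<uint32_t, uint64_t>) == 16, "aligned");
```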
-------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/pair_detail.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | namespace bght { 20 | namespace detail { 21 | template 22 | constexpr std::size_t next_alignment() { 23 | constexpr std::size_t n = sizeof(T); 24 | if (n <= 4) 25 | return 4; 26 | if (n <= 8) 27 | return 8; 28 | return 16; 29 | } 30 | constexpr std::size_t next_alignment(std::size_t n) { 31 | if (n <= 4) 32 | return 4; 33 | if (n <= 8) 34 | return 8; 35 | return 16; 36 | } 37 | 38 | template 39 | constexpr std::size_t pair_size() { 40 | return sizeof(T1) + sizeof(T2); 41 | } 42 | 43 | template 44 | constexpr std::size_t pair_alignment() { 45 | return next_alignment(pair_size()); 46 | } 47 | 48 | template 49 | constexpr std::size_t padding_size() { 50 | constexpr auto psz = pair_size(); 51 | constexpr auto apsz = next_alignment(pair_size()); 52 | if (psz > apsz) { 53 | constexpr auto nsz = (1ull + (psz / apsz)) * apsz; 54 | return nsz - psz; 55 | } 56 | return apsz - psz; 57 | } 58 | } // namespace detail 59 | } // namespace bght 60 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/ptx.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | namespace bght { 18 | namespace detail { 19 | namespace bits { 20 | // Bit Field Extract. 21 | __device__ __forceinline__ int bfe(uint32_t src, int num_bits) { 22 | unsigned mask; 23 | asm("bfe.u32 %0, %1, 0, %2;" : "=r"(mask) : "r"(src), "r"(num_bits)); 24 | return mask; 25 | } 26 | 27 | // Find most significant non - sign bit. 
28 | // bfind(0) = -1, bfind(1) = 0 29 | __device__ __forceinline__ int bfind(uint32_t src) { 30 | int msb; 31 | asm("bfind.u32 %0, %1;" : "=r"(msb) : "r"(src)); 32 | return msb; 33 | } 34 | __device__ __forceinline__ int bfind(uint64_t src) { 35 | int msb; 36 | asm("bfind.u64 %0, %1;" : "=r"(msb) : "l"(src)); 37 | return msb; 38 | } 39 | }; // namespace bits 40 | } // namespace detail 41 | } // namespace bght 42 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/rng.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | namespace bght { 19 | namespace detail { 20 | struct mars_rng_32 { 21 | uint32_t y; 22 | __host__ __device__ constexpr mars_rng_32() : y(2463534242) {} 23 | constexpr uint32_t __host__ __device__ operator()() { 24 | y ^= (y << 13); 25 | y = (y >> 17); 26 | return (y ^= (y << 5)); 27 | } 28 | }; 29 | } // namespace detail 30 | } // namespace bght 31 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/genzipf.hpp: -------------------------------------------------------------------------------- 1 | //= Author: Kenneth J. Christensen = 2 | //= University of South Florida = 3 | //= WWW: http://www.csee.usf.edu/~christen = 4 | //= Email: christen@csee.usf.edu = 5 | //=-------------------------------------------------------------------------= 6 | 7 | //========================================================================= 8 | //= Multiplicative LCG for generating uniform(0.0, 1.0) random numbers = 9 | //= - x_n = 7^5*x_(n-1)mod(2^31 - 1) = 10 | //= - With x seeded to 1 the 10000th x value should be 1043618065 = 11 | //= - From R. Jain, "The Art of Computer Systems Performance Analysis," = 12 | //= John Wiley & Sons, 1991. 
(Page 443, Figure 26.2) = 13 | //========================================================================= 14 | #pragma once 15 | #include 16 | #include 17 | 18 | double rand_val(int seed) { 19 | const long a = 16807; // Multiplier 20 | const long m = 2147483647; // Modulus 21 | const long q = 127773; // m div a 22 | const long r = 2836; // m mod a 23 | static long x; // Random int value 24 | long x_div_q; // x divided by q 25 | long x_mod_q; // x modulo q 26 | long x_new; // New x value 27 | 28 | // Set the seed if argument is non-zero and then return zero 29 | if (seed > 0) { 30 | x = seed; 31 | return (0.0); 32 | } 33 | 34 | // RNG using integer arithmetic 35 | x_div_q = x / q; 36 | x_mod_q = x % q; 37 | x_new = (a * x_mod_q) - (r * x_div_q); 38 | if (x_new > 0) 39 | x = x_new; 40 | else 41 | x = x_new + m; 42 | 43 | // Return a random value between 0.0 and 1.0 44 | return ((double)x / m); 45 | } 46 | 47 | uint32_t zipf(double alpha, uint32_t n) { 48 | static bool first = true; // Static first time flag 49 | static double c = 0; // Normalization constant 50 | static double* sum_probs; // Pre-calculated sum of probabilities 51 | double z; // Uniform random number (0 < z < 1) 52 | uint32_t zipf_value; // Computed exponential value to be returned 53 | uint32_t i; // Loop counter 54 | uint32_t low, high, mid; // Binary-search bounds 55 | 56 | // Compute normalization constant on first call only 57 | if (first == true) { 58 | for (i = 1; i <= n; i++) 59 | c = c + (1.0 / pow((double)i, alpha)); 60 | c = 1.0 / c; 61 | 62 | sum_probs = reinterpret_cast(std::malloc((n + 1) * sizeof(*sum_probs))); 63 | sum_probs[0] = 0; 64 | for (i = 1; i <= n; i++) { 65 | sum_probs[i] = sum_probs[i - 1] + c / pow((double)i, alpha); 66 | // std::cout << i << "," << sum_probs[i] << std::endl; 67 | } 68 | first = false; 69 | std::cout << "Computed probabilities" << std::endl; 70 | } 71 | 72 | // Pull a uniform random number (0 < z < 1) 73 | do { 74 | z = rand_val(0); 75 | } while ((z == 0) || (z == 1)); 76 | 77 | // Map z to the value 78 | low = 1; 79 | high = n; 80 | do { 81 | mid = floor((low + high) / 2); 82 | if (sum_probs[mid] >= z && sum_probs[mid - 1] < z) { 83 | zipf_value = mid; 84 | break; 85 | } else if (sum_probs[mid] >= z) { 86 | high = mid - 1; 87 | } else { 88 | low = mid + 1; 89 | } 90 | } while (low <= high); 91 | 92 | // Assert that zipf_value is between 1 and N 93 | assert((zipf_value >= 1) && (zipf_value <= n)); 94 | 95 | return zipf_value; 96 | } -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/gpu_timer.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | 20 | struct gpu_timer { 21 | gpu_timer(cudaStream_t stream = 0) : start_{}, stop_{}, stream_(stream) { 22 | cudaEventCreate(&start_); 23 | cudaEventCreate(&stop_); 24 | } 25 | void start_timer() { cudaEventRecord(start_, stream_); } 26 | void stop_timer() { cudaEventRecord(stop_, stream_); } 27 | float get_elapsed_ms() { 28 | compute_ms(); 29 | return elapsed_time_; 30 | } 31 | 32 | float get_elapsed_s() { 33 | compute_ms(); 34 | return elapsed_time_ * 0.001f; 35 | } 36 | ~gpu_timer() { 37 | cudaEventDestroy(start_); 38 | cudaEventDestroy(stop_); 39 | }; 40 | 41 | private: 42 | void compute_ms() { 43 | cudaEventSynchronize(stop_); 44 | cudaEventElapsedTime(&elapsed_time_, start_, stop_); 45 | } 46 | cudaEvent_t start_, stop_; 47 | cudaStream_t stream_; 48 | float elapsed_time_ = 0.0f; 49 | }; -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/perf_report.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | void std_cout_perf_report(float insertion_s, 20 | float find_s, 21 | std::size_t num_insertions, 22 | std::size_t num_finds) { 23 | std::cout << "inserted: " << num_insertions << " keys" << '\n'; 24 | std::cout << "finds: " << num_finds << " keys" << '\n'; 25 | 26 | double insertion_rate = double(num_insertions) * 1e-6 / insertion_s; 27 | double find_rate = double(num_finds) * 1e-6 / find_s; 28 | 29 | std::cout << "insert_rate: " << insertion_rate << " Mkey/s" << '\n'; 30 | std::cout << "find_rate: " << find_rate << " Mkey/s" << '\n'; 31 | } 32 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/rkg.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | namespace rkg { 24 | template 25 | value_type generate_value(key_type in) { 26 | return in + 1; 27 | } 28 | 29 | template 30 | void generate_uniform_unique_pairs(std::vector& keys, 31 | std::vector& values, 32 | size_type num_keys, 33 | bool cache = false, 34 | key_type min_key = 0) { 35 | keys.resize(num_keys); 36 | values.resize(num_keys); 37 | unsigned seed = 1; 38 | // bool cache = true; 39 | std::string dataset_dir = "dataset"; 40 | std::string dataset_name = std::to_string(num_keys) + "_" + std::to_string(seed); 41 | std::string dataset_path = dataset_dir + "/" + dataset_name; 42 | if (cache) { 43 | if (std::experimental::filesystem::exists(dataset_dir)) { 44 | if (std::experimental::filesystem::exists(dataset_path)) { 45 | std::cout << "Reading cached keys.." << std::endl; 46 | std::ifstream dataset(dataset_path, std::ios::binary); 47 | dataset.read((char*)keys.data(), sizeof(key_type) * num_keys); 48 | dataset.read((char*)values.data(), sizeof(value_type) * num_keys); 49 | dataset.close(); 50 | return; 51 | } 52 | } else { 53 | std::experimental::filesystem::create_directory(dataset_dir); 54 | } 55 | } 56 | std::random_device rd; 57 | std::mt19937 rng(seed); 58 | auto max_key = std::numeric_limits::max() - 1; 59 | std::uniform_int_distribution uni(min_key, max_key); 60 | std::unordered_set unique_keys; 61 | while (unique_keys.size() < num_keys) { 62 | unique_keys.insert(uni(rng)); 63 | // unique_keys.insert(unique_keys.size() + 1); 64 | } 65 | std::copy(unique_keys.cbegin(), unique_keys.cend(), keys.begin()); 66 | std::shuffle(keys.begin(), keys.end(), rng); 67 | 68 | #ifdef _WIN32 69 | // OpenMP + windows don't allow unsigned loops 70 | for (uint32_t i = 0; i < unique_keys.size(); i++) { 71 | values[i] = generate_value(keys[i]); 72 | } 73 | #else 74 | 75 | for (uint32_t i = 0; i < unique_keys.size(); i++) { 76 | values[i] = generate_value(keys[i]); 77 | } 78 | #endif 79 | 80 | if (cache) { 81 | std::cout << "Caching.." 
82 |     std::ofstream dataset(dataset_path, std::ios::binary);
83 |     dataset.write((char*)keys.data(), sizeof(key_type) * num_keys);
84 |     dataset.write((char*)values.data(), sizeof(value_type) * num_keys);
85 |     dataset.close();
86 |   }
87 | }
88 | 
89 | template <typename key_type, typename size_type>
90 | void generate_uniform_unique_keys(std::vector<key_type>& keys, size_type num_keys) {
91 |   keys.resize(num_keys);
92 |   unsigned seed = 1;
93 |   std::random_device rd;
94 |   std::mt19937 rng(seed);
95 |   auto max_key = std::numeric_limits<key_type>::max() - 1;
96 |   std::uniform_int_distribution<key_type> uni(0, max_key);
97 |   std::unordered_set<key_type> unique_keys;
98 |   while (unique_keys.size() < num_keys) {
99 |     unique_keys.insert(uni(rng));
100 |   }
101 |   std::copy(unique_keys.cbegin(), unique_keys.cend(), keys.begin());
102 |   std::shuffle(keys.begin(), keys.end(), rng);
103 | }
104 | } // namespace rkg
105 | 
--------------------------------------------------------------------------------
/sampling_server/src/include/system_config.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | /*-----------------------------iostack parameters---------------------------*/
4 | #include <cstdint>
5 | #define REG_SIZE 0x4000 // BAR 0 mapped size
6 | #define REG_CC 0x14 // addr: controller configuration
7 | #define REG_CC_EN 0x1 // mask: enable controller
8 | #define REG_CSTS 0x1c // addr: controller status
9 | #define REG_CSTS_RDY 0x1 // mask: controller ready
10 | #define REG_AQA 0x24 // addr: admin queue attributes
11 | #define REG_ASQ 0x28 // addr: admin submission queue base addr
12 | #define REG_ACQ 0x30 // addr: admin completion queue base addr
13 | #define REG_SQTDBL 0x1000 // addr: submission queue 0 tail doorbell
14 | #define REG_CQHDBL 0x1004 // addr: completion queue 0 head doorbell
15 | #define DBL_STRIDE 8
16 | #define PHASE_MASK 0x10000 // mask: phase tag
17 | #define HOST_PGSZ 0x1000
18 | #define DEVICE_PGSZ 0x10000
19 | #define CID_MASK 0xffff // mask: command id
20 | #define SC_MASK 0xff // mask: status code
21 | #define BROADCAST_NSID 0xffffffff // broadcast namespace id
22 | #define OPCODE_SET_FEATURES 0x09
23 | #define OPCODE_CREATE_IO_CQ 0x05
24 | #define OPCODE_CREATE_IO_SQ 0x01
25 | #define OPCODE_READ 0x02
26 | #define OPCODE_WRITE 0x01
27 | #define FID_NUM_QUEUES 0x07
28 | #define LB_SIZE 0x200
29 | #define RW_RETRY_MASK 0x80000000
30 | #define SQ_ITEM_SIZE 64
31 | #define WARP_SIZE 32
32 | #define SQ_HEAD_MASK 0xffff
33 | 
34 | #define MAX_IO_SIZE 4096
35 | #define ITEM_SIZE 512
36 | #define MAX_ITEMS (MAX_IO_SIZE / ITEM_SIZE)
37 | #define NUM_THREADS_PER_BLOCK 512
38 | #define ADMIN_QUEUE_DEPTH 64
39 | #define QUEUE_DEPTH 4096
40 | #define QUEUE_IOBUF_SIZE (MAX_IO_SIZE * QUEUE_DEPTH)
41 | #define NUM_PRP_ENTRIES (MAX_IO_SIZE / HOST_PGSZ)
42 | #define PRP_SIZE (NUM_PRP_ENTRIES * sizeof(uint64_t))
43 | #define NUM_LBS_PER_SSD 0x100000000
44 | #define MAX_SSDS_SUPPORTED 16
45 | 
46 | 
47 | #define INTERBATCH_CON 2 // inter-batch pipeline concurrency
48 | #define INTRABATCH_CON 3 // intra-batch pipeline concurrency
49 | 
50 | #define MAX_DEVICE 8
51 | #define MEMORY_USAGE 7
52 | #define TRAINMODE 0
53 | #define VALIDMODE 1
54 | #define TESTMODE 2
55 | 
56 | #define CACHEMISS_FLAG -2
57 | #define CACHECPU_FLAG -1
58 | 
59 | #define CHECK(ans) gpuAssert((ans), __FILE__, __LINE__)
60 | 
61 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true)
62 | {
63 |     if (code != cudaSuccess)
64 |     {
65 |         fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
66 |         if (abort)
67 |             exit(1);
68 |     }
69 | }
70 | 
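// Usage sketch for CHECK (illustrative, not part of this header): wrap CUDA
// runtime calls so a failure reports the file and line before exiting.
//
//   float* iobuf = nullptr;
//   CHECK(cudaMalloc(&iobuf, QUEUE_IOBUF_SIZE));
//   CHECK(cudaMemset(iobuf, 0, QUEUE_IOBUF_SIZE));
//   CHECK(cudaFree(iobuf));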
--------------------------------------------------------------------------------
/sampling_server/src/main.cu:
--------------------------------------------------------------------------------
1 | #include "server.h"
2 | #include <iostream>
3 | #include <vector>
4 | 
5 | int main(int argc, char** argv){
6 | 
7 |     std::cout<<"Start Sampling Server\n";
8 |     Server* server = NewGPUServer();
9 |     std::vector<int32_t> fanout;
10 |     fanout.push_back(25);
11 |     fanout.push_back(10);
12 |     server->Initialize(atoi(argv[1]), fanout, 1); // gpu number (default 1); in-memory mode (default true)
13 |     server->PreSc(atoi(argv[2])); // cache aggregation mode (default 0)
14 |     server->Run();
15 |     server->Finalize();
16 | 
17 | }
--------------------------------------------------------------------------------
/sampling_server/src/storage/feature_storage.cu:
--------------------------------------------------------------------------------
1 | #include "feature_storage.cuh"
2 | #include "feature_storage_impl.cuh"
3 | #include <cuda_runtime.h>
4 | 
5 | #include <cstdint>
6 | #include <cstdio>
7 | #include <cstdlib>
8 | #include <string>
9 | #include <vector>
10 | 
11 | class CompleteFeatureStorage : public FeatureStorage{
12 | public:
13 |     CompleteFeatureStorage(){
14 |     }
15 | 
16 |     virtual ~CompleteFeatureStorage(){};
17 | 
18 |     void Build(BuildInfo* info, int in_memory_mode) override {
19 |         int32_t partition_count = info->partition_count;
20 |         total_num_nodes_ = info->total_num_nodes;
21 |         float_feature_len_ = info->float_feature_len;
22 |         float* host_float_feature = info->host_float_feature;
23 | 
24 |         if(in_memory_mode){
25 |             cudaHostGetDevicePointer(&float_feature_, host_float_feature, 0); // zero-copy view of the pinned host feature table
26 |         }
27 |         cudaCheckError();
28 | 
29 |         training_set_num_.resize(partition_count);
30 |         training_set_ids_.resize(partition_count);
31 |         training_labels_.resize(partition_count);
32 | 
33 |         validation_set_num_.resize(partition_count);
34 |         validation_set_ids_.resize(partition_count);
35 |         validation_labels_.resize(partition_count);
36 | 
37 |         testing_set_num_.resize(partition_count);
38 |         testing_set_ids_.resize(partition_count);
39 |         testing_labels_.resize(partition_count);
40 | 
41 |         partition_count_ = partition_count;
42 | 
43 |         for(int32_t i = 0; i < partition_count_; i++){
44 |             int32_t part_id = i;
45 |             training_set_num_[part_id] = info->training_set_num[part_id];
46 |             validation_set_num_[part_id] = info->validation_set_num[part_id];
47 |             testing_set_num_[part_id] = info->testing_set_num[part_id];
48 | 
49 |             cudaSetDevice(part_id);
50 |             cudaCheckError();
51 | 
52 |             int32_t* train_ids;
53 |             cudaMalloc(&train_ids, training_set_num_[part_id] * sizeof(int32_t));
54 |             cudaMemcpy(train_ids, info->training_set_ids[part_id].data(), training_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice);
55 |             training_set_ids_[part_id] = train_ids;
56 |             cudaCheckError();
57 | 
58 |             int32_t* valid_ids;
59 |             cudaMalloc(&valid_ids, validation_set_num_[part_id] * sizeof(int32_t));
60 |             cudaMemcpy(valid_ids, info->validation_set_ids[part_id].data(), validation_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice);
61 |             validation_set_ids_[part_id] = valid_ids;
62 |             cudaCheckError();
63 | 
64 |             int32_t* test_ids;
65 |             cudaMalloc(&test_ids, testing_set_num_[part_id] * sizeof(int32_t));
66 |             cudaMemcpy(test_ids, info->testing_set_ids[part_id].data(), testing_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice);
67 |             testing_set_ids_[part_id] = test_ids;
68 |             cudaCheckError();
69 | 
70 |             int32_t* train_labels;
71 |             cudaMalloc(&train_labels, training_set_num_[part_id] * sizeof(int32_t));
72 | 
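            // The uploads in this loop all repeat one malloc+copy shape; a
            // helper like the following would capture the intent (hypothetical,
            // not part of Legion's API):
            //
            //   static int32_t* UploadToDevice(const std::vector<int32_t>& src) {
            //     int32_t* dst;
            //     cudaMalloc(&dst, src.size() * sizeof(int32_t));
            //     cudaMemcpy(dst, src.data(), src.size() * sizeof(int32_t),
            //                cudaMemcpyHostToDevice);
            //     return dst;
            //   }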
cudaMemcpy(train_labels, info->training_labels[part_id].data(), training_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 73 | training_labels_[part_id] = train_labels; 74 | cudaCheckError(); 75 | 76 | int32_t* valid_labels; 77 | cudaMalloc(&valid_labels, validation_set_num_[part_id] * sizeof(int32_t)); 78 | cudaMemcpy(valid_labels, info->validation_labels[part_id].data(), validation_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 79 | validation_labels_[part_id] = valid_labels; 80 | cudaCheckError(); 81 | 82 | int32_t* test_labels; 83 | cudaMalloc(&test_labels, testing_set_num_[part_id] * sizeof(int32_t)); 84 | cudaMemcpy(test_labels, info->testing_labels[part_id].data(), testing_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 85 | testing_labels_[part_id] = test_labels; 86 | cudaCheckError(); 87 | 88 | } 89 | 90 | }; 91 | 92 | void Finalize() override { 93 | cudaFreeHost(float_feature_); 94 | for(int32_t i = 0; i < partition_count_; i++){ 95 | cudaSetDevice(i); 96 | cudaFree(training_set_ids_[i]); 97 | cudaFree(validation_set_ids_[i]); 98 | cudaFree(testing_set_ids_[i]); 99 | cudaFree(training_labels_[i]); 100 | cudaFree(validation_labels_[i]); 101 | cudaFree(testing_labels_[i]); 102 | } 103 | } 104 | 105 | int32_t* GetTrainingSetIds(int32_t part_id) const override { 106 | return training_set_ids_[part_id]; 107 | } 108 | int32_t* GetValidationSetIds(int32_t part_id) const override { 109 | return validation_set_ids_[part_id]; 110 | } 111 | int32_t* GetTestingSetIds(int32_t part_id) const override { 112 | return testing_set_ids_[part_id]; 113 | } 114 | 115 | int32_t* GetTrainingLabels(int32_t part_id) const override { 116 | return training_labels_[part_id]; 117 | }; 118 | int32_t* GetValidationLabels(int32_t part_id) const override { 119 | return validation_labels_[part_id]; 120 | } 121 | int32_t* GetTestingLabels(int32_t part_id) const override { 122 | return testing_labels_[part_id]; 123 | } 124 | 125 | int32_t TrainingSetSize(int32_t part_id) const override { 126 | return training_set_num_[part_id]; 127 | } 128 | int32_t ValidationSetSize(int32_t part_id) const override { 129 | return validation_set_num_[part_id]; 130 | } 131 | int32_t TestingSetSize(int32_t part_id) const override { 132 | return testing_set_num_[part_id]; 133 | } 134 | 135 | int32_t TotalNodeNum() const override { 136 | return total_num_nodes_; 137 | } 138 | 139 | float* GetAllFloatFeature() const override { 140 | return float_feature_; 141 | } 142 | int32_t GetFloatFeatureLen() const override { 143 | return float_feature_len_; 144 | } 145 | 146 | void IOSubmit(int32_t* sampled_ids, int32_t* cache_index, 147 | int32_t* node_counter, float* dst_float_buffer, 148 | int32_t op_id, int32_t dev_id, cudaStream_t strm_hdl) override { 149 | //TODO 150 | } 151 | 152 | void IOComplete() override { 153 | //TODO 154 | } 155 | 156 | private: 157 | std::vector training_set_num_; 158 | std::vector validation_set_num_; 159 | std::vector testing_set_num_; 160 | 161 | std::vector training_set_ids_; 162 | std::vector validation_set_ids_; 163 | std::vector testing_set_ids_; 164 | 165 | std::vector training_labels_; 166 | std::vector validation_labels_; 167 | std::vector testing_labels_; 168 | 169 | int32_t partition_count_; 170 | int32_t total_num_nodes_; 171 | float* float_feature_; 172 | int32_t float_feature_len_; 173 | 174 | friend FeatureStorage* NewCompleteFeatureStorage(); 175 | }; 176 | 177 | extern "C" 178 | FeatureStorage* NewCompleteFeatureStorage(){ 179 | CompleteFeatureStorage* ret = 
new CompleteFeatureStorage();
180 |     return ret;
181 | }
182 | 
--------------------------------------------------------------------------------
/sampling_server/src/storage/feature_storage.cuh:
--------------------------------------------------------------------------------
1 | #ifndef FEATURE_STORAGE_H_
2 | #define FEATURE_STORAGE_H_
3 | 
4 | #include "buildinfo.h"
5 | 
6 | class FeatureStorage {
7 | public:
8 |     virtual ~FeatureStorage() = default;
9 | 
10 |     virtual void Build(BuildInfo* info, int in_memory_mode) = 0;
11 |     virtual void Finalize() = 0;
12 | 
13 |     virtual int32_t* GetTrainingSetIds(int32_t part_id) const = 0;
14 |     virtual int32_t* GetValidationSetIds(int32_t part_id) const = 0;
15 |     virtual int32_t* GetTestingSetIds(int32_t part_id) const = 0;
16 | 
17 |     virtual int32_t* GetTrainingLabels(int32_t part_id) const = 0;
18 |     virtual int32_t* GetValidationLabels(int32_t part_id) const = 0;
19 |     virtual int32_t* GetTestingLabels(int32_t part_id) const = 0;
20 | 
21 |     virtual int32_t TrainingSetSize(int32_t part_id) const = 0;
22 |     virtual int32_t ValidationSetSize(int32_t part_id) const = 0;
23 |     virtual int32_t TestingSetSize(int32_t part_id) const = 0;
24 | 
25 |     virtual int32_t TotalNodeNum() const = 0;
26 |     virtual float* GetAllFloatFeature() const = 0;
27 |     virtual int32_t GetFloatFeatureLen() const = 0;
28 | 
29 |     virtual void IOSubmit(int32_t* sampled_ids, int32_t* cache_index,
30 |                           int32_t* node_counter, float* dst_float_buffer,
31 |                           int32_t op_id, int32_t dev_id, cudaStream_t strm_hdl) = 0;
32 | 
33 |     virtual void IOComplete() = 0;
34 | };
35 | 
36 | extern "C"
37 | FeatureStorage* NewCompleteFeatureStorage();
38 | 
39 | #endif // FEATURE_STORAGE_H_
--------------------------------------------------------------------------------
/sampling_server/src/storage/feature_storage_impl.cuh:
--------------------------------------------------------------------------------
1 | 
2 | #ifndef FEATURE_STORAGE_IMPL_H_
3 | #define FEATURE_STORAGE_IMPL_H_
4 | 
5 | #include <cuda_runtime.h>
6 | #include <cstdio>
7 | #include <cstdlib>
8 | 
9 | // Macro for checking cuda errors following a cuda launch or api call
10 | #define cudaCheckError()                                              \
11 |   {                                                                   \
12 |     cudaError_t e = cudaGetLastError();                               \
13 |     if (e != cudaSuccess) {                                           \
14 |       printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__,        \
15 |              cudaGetErrorString(e));                                  \
16 |       exit(EXIT_FAILURE);                                             \
17 |     }                                                                 \
18 |   }
19 | 
20 | 
21 | #endif
--------------------------------------------------------------------------------
/sampling_server/src/storage/graph_storage.cu:
--------------------------------------------------------------------------------
1 | #include "graph_storage.cuh"
2 | #include "graph_storage_impl.cuh"
3 | 
4 | class CompleteGraphStorage : public GraphStorage {
5 | public:
6 |     CompleteGraphStorage() {
7 |     }
8 | 
9 |     virtual ~CompleteGraphStorage() {
10 |     }
11 | 
12 |     void Build(BuildInfo* info) override {
13 |         int32_t partition_count = info->partition_count;
14 |         partition_count_ = partition_count;
15 |         node_num_ = info->total_num_nodes;
16 |         edge_num_ = info->total_edge_num;
17 |         cache_edge_num_ = info->cache_edge_num;
18 | 
19 |         csr_node_index_.resize(partition_count_);
20 |         csr_dst_node_ids_.resize(partition_count_);
21 |         partition_index_.resize(partition_count_);
22 |         partition_offset_.resize(partition_count_);
23 | 
24 |         d_global_count_.resize(partition_count);
25 |         h_global_count_.resize(partition_count);
26 |         h_cache_hit_.resize(partition_count);
27 |         find_iter_.resize(partition_count);
28 |         h_batch_size_.resize(partition_count);
29 | 
30 |         for(int32_t i = 0; i < partition_count; i++){
31 | 
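            // Layout note: each GPU i holds a device-resident table of
            // (partition_count + 1) pointers for both the CSR index and the
            // CSR adjacency. Slots 0..partition_count-1 are later filled with
            // the per-GPU cached shards (see GraphCache), and the extra last
            // slot receives the zero-copy pointer to the full pinned-host CSR
            // via assign_memory(..., partition_count), so sampling kernels can
            // reach any partition, or fall back to the CPU copy, with one
            // indirection.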
cudaSetDevice(i); 32 | cudaMalloc(&csr_node_index_[i], (partition_count + 1) * sizeof(int64_t*)); 33 | cudaMalloc(&csr_dst_node_ids_[i], (partition_count + 1) * sizeof(int32_t*)); 34 | cudaMalloc(&d_global_count_[i], 4); 35 | h_global_count_[i] = (int32_t*)malloc(4); 36 | h_cache_hit_[i] = 0; 37 | find_iter_[i] = 0; 38 | h_batch_size_[i] = 0; 39 | } 40 | 41 | src_size_.resize(partition_count); 42 | dst_size_.resize(partition_count); 43 | cudaCheckError(); 44 | 45 | cudaSetDevice(0); 46 | 47 | int64_t* pin_csr_node_index; 48 | int32_t* pin_csr_dst_node_ids; 49 | 50 | h_csr_node_index_ = info->csr_node_index; 51 | h_csr_dst_node_ids_ = info->csr_dst_node_ids; 52 | 53 | // int64_t* d_csr_node_index; 54 | // int32_t* d_csr_dst_node_ids; 55 | // cudaMalloc(&d_csr_node_index, (node_num_ + 1) * sizeof(int64_t)); 56 | // cudaMalloc(&d_csr_dst_node_ids, edge_num_ * sizeof(int32_t)); 57 | // cudaMemcpy(d_csr_node_index, h_csr_node_index_, (node_num_ + 1) * sizeof(int64_t), cudaMemcpyHostToDevice); 58 | // cudaMemcpy(d_csr_dst_node_ids, h_csr_dst_node_ids_, edge_num_ * sizeof(int32_t), cudaMemcpyHostToDevice); 59 | 60 | cudaHostGetDevicePointer(&pin_csr_node_index, h_csr_node_index_, 0); 61 | cudaHostGetDevicePointer(&pin_csr_dst_node_ids, h_csr_dst_node_ids_, 0); 62 | assign_memory<<<1,1>>>(csr_dst_node_ids_[0], pin_csr_dst_node_ids, csr_node_index_[0], pin_csr_node_index, partition_count); 63 | cudaCheckError(); 64 | // assign_memory<<<1,1>>>(csr_dst_node_ids_[0], d_csr_dst_node_ids, csr_node_index_[0], d_csr_node_index, partition_count); 65 | // cudaCheckError(); 66 | // for(int i = 0; i < partition_count; i++){ 67 | // cudaMemcpy(csr_node_index_[i], csr_node_index_[0], (partition_count + 1) * sizeof(int64_t*), cudaMemcpyDeviceToDevice); 68 | // cudaMemcpy(csr_dst_node_ids_[i], csr_dst_node_ids_[0], (partition_count + 1) * sizeof(int32_t*), cudaMemcpyDeviceToDevice); 69 | // } 70 | csr_node_index_cpu_ = pin_csr_node_index; 71 | csr_dst_node_ids_cpu_ = pin_csr_dst_node_ids; 72 | 73 | } 74 | 75 | 76 | void GraphCache(int32_t* QT, int32_t Ki, int32_t Kg, int32_t capacity){ 77 | cudaMemcpy(csr_node_index_[Ki * Kg], csr_node_index_[0], (partition_count_ + 1) * sizeof(int64_t*), cudaMemcpyDeviceToDevice); 78 | cudaCheckError(); 79 | cudaMemcpy(csr_dst_node_ids_[Ki * Kg], csr_dst_node_ids_[0], (partition_count_ + 1) * sizeof(int32_t*), cudaMemcpyDeviceToDevice); 80 | cudaCheckError(); 81 | for(int32_t i = 0; i < Kg; i++){ 82 | cudaSetDevice(Ki * Kg + i); 83 | int64_t* neighbor_count; 84 | cudaMalloc(&neighbor_count, capacity * sizeof(int64_t)); 85 | GetNeighborCount<<<128, 1024>>>(QT, Kg, i, capacity, csr_node_index_cpu_, neighbor_count); 86 | 87 | int64_t* d_csr_node_index; 88 | cudaMalloc(&d_csr_node_index, (int64_t(capacity + 1)*sizeof(int64_t))); 89 | cudaMemset(d_csr_node_index, 0, (int64_t(capacity + 1)*sizeof(int64_t))); 90 | thrust::inclusive_scan(thrust::device, neighbor_count, neighbor_count + capacity, d_csr_node_index + 1); 91 | cudaCheckError(); 92 | int64_t* h_csr_node_index = (int64_t*)malloc((capacity + 1) * sizeof(int64_t)); 93 | cudaMemcpy(h_csr_node_index, d_csr_node_index, (capacity + 1) * sizeof(int64_t), cudaMemcpyDeviceToHost); 94 | 95 | int32_t* d_csr_dst_node_ids; 96 | cudaMalloc(&d_csr_dst_node_ids, int64_t(int64_t(h_csr_node_index[capacity]) * sizeof(int32_t))); 97 | 98 | TopoFillUp<<<80, 1024>>>(QT, Kg, i, capacity, csr_node_index_cpu_, csr_dst_node_ids_cpu_, d_csr_node_index, d_csr_dst_node_ids); 99 | cudaCheckError(); 100 | 101 | assign_memory<<<1,1>>>(csr_dst_node_ids_[Ki * 
Kg], d_csr_dst_node_ids, csr_node_index_[Ki * Kg], d_csr_node_index, Ki * Kg + i); 102 | cudaCheckError(); 103 | cudaFree(neighbor_count); 104 | } 105 | for(int32_t i = 1; i < Kg; i++){ 106 | cudaMemcpy(csr_node_index_[Ki * Kg + i], csr_node_index_[Ki * Kg], (partition_count_ + 1) * sizeof(int64_t*), cudaMemcpyDeviceToDevice); 107 | cudaCheckError(); 108 | cudaMemcpy(csr_dst_node_ids_[Ki * Kg + i], csr_dst_node_ids_[Ki * Kg], (partition_count_ + 1) * sizeof(int32_t*), cudaMemcpyDeviceToDevice); 109 | cudaCheckError(); 110 | } 111 | } 112 | 113 | void Finalize() override { 114 | cudaFreeHost(csr_node_index_cpu_); 115 | cudaFreeHost(csr_dst_node_ids_cpu_); 116 | // for(int32_t i = 0; i < partition_count_; i++){ 117 | // cudaFree(partition_index_[i]); 118 | // cudaFree(partition_offset_[i]); 119 | // } 120 | } 121 | 122 | //CSR 123 | int32_t GetPartitionCount() const override { 124 | return partition_count_; 125 | } 126 | int64_t** GetCSRNodeIndex(int32_t part_id) const override { 127 | return csr_node_index_[part_id]; 128 | } 129 | int32_t** GetCSRNodeMatrix(int32_t part_id) const override { 130 | return csr_dst_node_ids_[part_id]; 131 | } 132 | 133 | int64_t* GetCSRNodeIndexCPU() const override { 134 | return csr_node_index_cpu_; 135 | } 136 | 137 | int32_t* GetCSRNodeMatrixCPU() const override { 138 | return csr_dst_node_ids_cpu_; 139 | } 140 | 141 | int64_t Src_Size(int32_t part_id) const override { 142 | return src_size_[part_id]; 143 | } 144 | int64_t Dst_Size(int32_t part_id) const override { 145 | return dst_size_[part_id]; 146 | } 147 | char* PartitionIndex(int32_t part_id) const override { 148 | return partition_index_[part_id]; 149 | } 150 | int32_t* PartitionOffset(int32_t part_id) const override { 151 | return partition_offset_[part_id]; 152 | } 153 | 154 | private: 155 | std::vector src_size_; 156 | std::vector dst_size_; 157 | 158 | int32_t node_num_; 159 | int64_t edge_num_; 160 | int64_t cache_edge_num_; 161 | 162 | //CSR graph, every partition has a ptr copy 163 | int32_t partition_count_; 164 | std::vector csr_node_index_; 165 | std::vector csr_dst_node_ids_; 166 | int64_t* csr_node_index_cpu_; 167 | int32_t* csr_dst_node_ids_cpu_; 168 | 169 | int64_t* h_csr_node_index_; 170 | int32_t* h_csr_dst_node_ids_; 171 | 172 | std::vector partition_index_; 173 | std::vector partition_offset_; 174 | 175 | std::vector h_global_count_; 176 | std::vector d_global_count_; 177 | 178 | 179 | std::vector find_iter_; 180 | std::vector h_cache_hit_; 181 | std::vector h_batch_size_; 182 | }; 183 | 184 | extern "C" 185 | GraphStorage* NewCompleteGraphStorage(){ 186 | CompleteGraphStorage* ret = new CompleteGraphStorage(); 187 | return ret; 188 | } 189 | -------------------------------------------------------------------------------- /sampling_server/src/storage/graph_storage.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef GRAPH_STORAGE_H_ 3 | #define GRAPH_STORAGE_H_ 4 | 5 | #include "buildinfo.h" 6 | 7 | class GraphStorage { 8 | public: 9 | virtual ~GraphStorage() = default; 10 | //build 11 | virtual void Build(BuildInfo* info) = 0; 12 | virtual void GraphCache(int32_t* QT, int32_t Ki, int32_t Kg, int32_t capacity) = 0; 13 | virtual void Finalize() = 0; 14 | //CSR 15 | virtual int32_t GetPartitionCount() const = 0; 16 | virtual int64_t** GetCSRNodeIndex(int32_t part_id) const = 0; 17 | virtual int32_t** GetCSRNodeMatrix(int32_t part_id) const = 0; 18 | virtual int64_t* GetCSRNodeIndexCPU() const = 0; 19 | virtual int32_t* 
GetCSRNodeMatrixCPU() const = 0;
20 |     virtual int64_t Src_Size(int32_t part_id) const = 0;
21 |     virtual int64_t Dst_Size(int32_t part_id) const = 0;
22 |     virtual char* PartitionIndex(int32_t part_id) const = 0;
23 |     virtual int32_t* PartitionOffset(int32_t part_id) const = 0;
24 | };
25 | extern "C"
26 | GraphStorage* NewCompleteGraphStorage();
27 | 
28 | #endif // GRAPH_STORAGE_H_
--------------------------------------------------------------------------------
/sampling_server/src/storage/graph_storage_impl.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #ifndef GRAPH_STORAGE_IMPL_H_
3 | #define GRAPH_STORAGE_IMPL_H_
4 | 
5 | #include <cuda_runtime.h>
6 | #include <cstdint>
7 | #include <cstdio>
8 | #include <cstdlib>
9 | #include <vector>
10 | #include <iostream>
11 | #include <thrust/scan.h>
12 | #include <thrust/execution_policy.h>
13 | 
14 | 
15 | // Macro for checking cuda errors following a cuda launch or api call
16 | #define cudaCheckError()                                              \
17 |   {                                                                   \
18 |     cudaError_t e = cudaGetLastError();                               \
19 |     if (e != cudaSuccess) {                                           \
20 |       printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__,        \
21 |              cudaGetErrorString(e));                                  \
22 |       exit(EXIT_FAILURE);                                             \
23 |     }                                                                 \
24 |   }
25 | 
26 | 
27 | __global__ void assign_memory(int32_t** int32_pptr, int32_t* int32_ptr, int64_t** int64_pptr, int64_t* int64_ptr, int32_t device_id){
28 |     int32_pptr[device_id] = int32_ptr;
29 |     int64_pptr[device_id] = int64_ptr;
30 | }
31 | 
32 | 
33 | __global__ void GetNeighborCount(int32_t* QT, int32_t Kg, int32_t Ki, int32_t capacity, int64_t* csr_node_index_cpu, int64_t* neighbor_count){
34 |     for(int32_t thread_idx = threadIdx.x + blockDim.x * blockIdx.x; thread_idx < capacity; thread_idx += gridDim.x * blockDim.x){
35 |         int32_t cache_id = QT[thread_idx * Kg + Ki];
36 |         int64_t count = csr_node_index_cpu[cache_id + 1] - csr_node_index_cpu[cache_id];
37 |         neighbor_count[thread_idx] = count;
38 |     }
39 | }
40 | 
41 | __global__ void TopoFillUp(int32_t* QT, int32_t Kg, int32_t Ki, int32_t capacity,
42 |                            int64_t* csr_node_index_cpu, int32_t* csr_dst_node_ids_cpu,
43 |                            int64_t* d_csr_node_index, int32_t* d_csr_dst_node_ids){
44 |     for(int32_t thread_idx = threadIdx.x + blockDim.x * blockIdx.x; thread_idx < capacity; thread_idx += gridDim.x * blockDim.x){
45 |         int32_t cache_id = QT[thread_idx * Kg + Ki];
46 |         int64_t count = csr_node_index_cpu[cache_id + 1] - csr_node_index_cpu[cache_id];
47 |         for(int i = 0; i < count; i++){
48 |             int32_t neighbor_id = csr_dst_node_ids_cpu[csr_node_index_cpu[cache_id] + i];
49 |             int64_t start_off = d_csr_node_index[thread_idx];
50 |             d_csr_dst_node_ids[start_off + i] = neighbor_id;
51 |         }
52 |     }
53 | }
54 | 
55 | 
56 | 
57 | #endif
--------------------------------------------------------------------------------
/sampling_server/src/storage/storage_management.cuh:
--------------------------------------------------------------------------------
1 | #include <string>
2 | #include "graph_storage.cuh"
3 | #include "feature_storage.cuh"
4 | #include "cache.cuh"
5 | #include "ipc_service.h"
6 | 
7 | class StorageManagement {
8 | public:
9 | 
10 |     void Initialze(int32_t partition_count, int32_t in_memory_mode);
11 | 
12 |     GraphStorage* GetGraph();
13 | 
14 |     FeatureStorage* GetFeature();
15 | 
16 |     UnifiedCache* GetCache();
17 | 
18 |     IPCEnv* GetIPCEnv();
19 | 
20 |     int32_t Shard_To_Device(int32_t part_id);
21 | 
22 |     int32_t Shard_To_Partition(int32_t part_id);
23 | 
24 |     int32_t Central_Device();
25 | 
26 | private:
27 |     void EnableP2PAccess();
28 | 
29 |     void ConfigPartition(BuildInfo* info, int32_t partition_count);
30 | 
31 |     void ReadMetaFIle(BuildInfo* info);
32 | 
33 |     void
LoadGraph(BuildInfo* info); 34 | 35 | void LoadFeature(BuildInfo* info); 36 | 37 | int32_t in_memory_mode_; 38 | int32_t partition_; 39 | 40 | int64_t cache_edge_num_; 41 | int64_t edge_num_; 42 | int32_t node_num_; 43 | 44 | int32_t training_set_num_; 45 | int32_t validation_set_num_; 46 | int32_t testing_set_num_; 47 | 48 | int32_t float_feature_len_; 49 | 50 | int64_t cache_memory_; 51 | 52 | std::string dataset_path_; 53 | int32_t raw_batch_size_; 54 | int32_t epoch_; 55 | int32_t num_ssd_; 56 | int32_t num_queues_per_ssd_; 57 | int32_t cpu_cache_capacity_;//for Helios 58 | int32_t gpu_cache_capacity_;//for Helios 59 | 60 | GraphStorage* graph_; 61 | FeatureStorage* feature_; 62 | UnifiedCache* cache_; 63 | IPCEnv* env_; 64 | }; 65 | 66 | 67 | -------------------------------------------------------------------------------- /sampling_server/src/storage/storage_management_impl.cuh: -------------------------------------------------------------------------------- 1 | #ifndef STORAGE_MANAGEMENT_IMPL_H_ 2 | #define STORAGE_MANAGEMENT_IMPL_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | 33 | 34 | // Macro for checking cuda errors following a cuda launch or api call 35 | #define cudaCheckError() \ 36 | { \ 37 | cudaError_t e = cudaGetLastError(); \ 38 | if (e != cudaSuccess) { \ 39 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 40 | cudaGetErrorString(e)); \ 41 | exit(EXIT_FAILURE); \ 42 | } \ 43 | } 44 | 45 | 46 | void mmap_trainingset_read(std::string &training_file, std::vector& training_set_ids){ 47 | int64_t t_idx = 0; 48 | int32_t fd = open(training_file.c_str(), O_RDONLY); 49 | if(fd == -1){ 50 | std::cout<<"cannout open file: "<& labels){ 143 | int64_t n_idx = 0; 144 | int32_t fd = open(labels_file.c_str(), O_RDONLY); 145 | if(fd == -1){ 146 | std::cout<<"cannout open file: "< 3 | #include 4 | 5 | int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info) 6 | { 7 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 8 | info->size = sz; 9 | info->shmHandle = CreateFileMapping(INVALID_HANDLE_VALUE, 10 | NULL, 11 | PAGE_READWRITE, 12 | 0, 13 | (DWORD)sz, 14 | name); 15 | if (info->shmHandle == 0) { 16 | return GetLastError(); 17 | } 18 | 19 | info->addr = MapViewOfFile(info->shmHandle, FILE_MAP_ALL_ACCESS, 0, 0, sz); 20 | if (info->addr == NULL) { 21 | return GetLastError(); 22 | } 23 | 24 | return 0; 25 | #else 26 | int status = 0; 27 | 28 | info->size = sz; 29 | 30 | info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777); 31 | if (info->shmFd < 0) { 32 | return errno; 33 | } 34 | 35 | status = ftruncate(info->shmFd, sz); 36 | if (status != 0) { 37 | return status; 38 | } 39 | 40 | info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); 41 | if (info->addr == NULL) { 42 | return errno; 43 | } 44 | 45 | return 0; 46 | #endif 47 | } 48 | 49 | int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info) 50 | { 51 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 52 | info->size = sz; 53 | 54 | info->shmHandle = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name); 55 | if (info->shmHandle == 0) { 56 | return GetLastError(); 57 | } 58 | 59 | info->addr = 
MapViewOfFile(info->shmHandle, FILE_MAP_ALL_ACCESS, 0, 0, sz); 60 | if (info->addr == NULL) { 61 | return GetLastError(); 62 | } 63 | 64 | return 0; 65 | #else 66 | info->size = sz; 67 | 68 | info->shmFd = shm_open(name, O_RDWR, 0777); 69 | if (info->shmFd < 0) { 70 | return errno; 71 | } 72 | 73 | info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); 74 | if (info->addr == NULL) { 75 | return errno; 76 | } 77 | 78 | return 0; 79 | #endif 80 | } 81 | 82 | void sharedMemoryClose(sharedMemoryInfo *info) 83 | { 84 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 85 | if (info->addr) { 86 | UnmapViewOfFile(info->addr); 87 | } 88 | if (info->shmHandle) { 89 | CloseHandle(info->shmHandle); 90 | } 91 | #else 92 | if (info->addr) { 93 | munmap(info->addr, info->size); 94 | } 95 | if (info->shmFd) { 96 | close(info->shmFd); 97 | } 98 | #endif 99 | } 100 | 101 | int spawnProcess(Process *process, const char *app, char * const *args) 102 | { 103 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 104 | STARTUPINFO si = {0}; 105 | BOOL status; 106 | size_t arglen = 0; 107 | size_t argIdx = 0; 108 | std::string arg_string; 109 | memset(process, 0, sizeof(*process)); 110 | 111 | while (*args) { 112 | arg_string.append(*args).append(1, ' '); 113 | args++; 114 | } 115 | 116 | status = CreateProcess(app, LPSTR(arg_string.c_str()), NULL, NULL, FALSE, 0, NULL, NULL, &si, process); 117 | 118 | return status ? 0 : GetLastError(); 119 | #else 120 | *process = fork(); 121 | if (*process == 0) { 122 | if (0 > execvp(app, args)) { 123 | return errno; 124 | } 125 | } 126 | else if (*process < 0) { 127 | return errno; 128 | } 129 | return 0; 130 | #endif 131 | } 132 | 133 | int waitProcess(Process *process) 134 | { 135 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 136 | DWORD exitCode; 137 | WaitForSingleObject(process->hProcess, INFINITE); 138 | GetExitCodeProcess(process->hProcess, &exitCode); 139 | CloseHandle(process->hProcess); 140 | CloseHandle(process->hThread); 141 | return (int)exitCode; 142 | #else 143 | int status = 0; 144 | do { 145 | if (0 > waitpid(*process, &status, 0)) { 146 | return errno; 147 | } 148 | } while (!WIFEXITED(status)); 149 | return WEXITSTATUS(status); 150 | #endif 151 | } 152 | -------------------------------------------------------------------------------- /training_backend/helper_multiprocess.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-2018 NVIDIA Corporation. All rights reserved. 3 | * 4 | * Please refer to the NVIDIA end user license agreement (EULA) associated 5 | * with this source code for terms and conditions that govern your use of 6 | * this software. Any use, reproduction, disclosure, or distribution of 7 | * this software and related documentation outside the terms of the EULA 8 | * is strictly prohibited. 
9 | * 10 | */ 11 | 12 | #ifndef HELPER_MULTIPROCESS_H 13 | #define HELPER_MULTIPROCESS_H 14 | 15 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 16 | #ifndef WIN32_LEAN_AND_MEAN 17 | #define WIN32_LEAN_AND_MEAN 18 | #endif 19 | #include 20 | #else 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #endif 27 | 28 | typedef struct sharedMemoryInfo_st { 29 | void *addr; 30 | size_t size; 31 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 32 | HANDLE shmHandle; 33 | #else 34 | int shmFd; 35 | #endif 36 | } sharedMemoryInfo; 37 | 38 | int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info); 39 | 40 | int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info); 41 | 42 | void sharedMemoryClose(sharedMemoryInfo *info); 43 | 44 | 45 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 46 | typedef PROCESS_INFORMATION Process; 47 | #else 48 | typedef pid_t Process; 49 | #endif 50 | 51 | int spawnProcess(Process *process, const char *app, char * const *args); 52 | 53 | int waitProcess(Process *process); 54 | 55 | #endif // HELPER_MULTIPROCESS_H 56 | -------------------------------------------------------------------------------- /training_backend/ipc_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "helper_multiprocess.h" 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include "ipc_service.h" 12 | 13 | #include 14 | #include 15 | 16 | #define MAX_DEVICE 8 17 | #define MEMORY_USAGE 7 18 | 19 | // Macro for checking cuda errors following a cuda launch or api call 20 | #define cudaCheckError() \ 21 | { \ 22 | cudaError_t e = cudaGetLastError(); \ 23 | if (e != cudaSuccess) { \ 24 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 25 | cudaGetErrorString(e)); \ 26 | exit(EXIT_FAILURE); \ 27 | } \ 28 | } 29 | 30 | typedef struct shmStruct_st { 31 | int32_t steps[3]; 32 | cudaIpcMemHandle_t memHandle[MAX_DEVICE][INTERBATCH_CON][MEMORY_USAGE]; 33 | } shmStruct; 34 | 35 | class GPUIPCEnv : public IPCEnv { 36 | public: 37 | int Initialize() override { 38 | volatile shmStruct *shm = NULL; 39 | int central_device = -1; 40 | cudaGetDevice(¢ral_device); 41 | cudaCheckError(); 42 | sharedMemoryInfo info; 43 | const char shmName[] = "simpleIPCshm"; 44 | if (sharedMemoryCreate(shmName, sizeof(*shm), &info) != 0) { 45 | printf("Failed to create shared memory slab\n"); 46 | exit(EXIT_FAILURE); 47 | } 48 | 49 | shm = (volatile shmStruct *)info.addr; 50 | train_step_ = shm->steps[0]; 51 | valid_step_ = shm->steps[1]; 52 | test_step_ = shm->steps[2]; 53 | ids_.resize(INTERBATCH_CON); 54 | float_features_.resize(INTERBATCH_CON); 55 | labels_.resize(INTERBATCH_CON); 56 | agg_src_.resize(INTERBATCH_CON); 57 | agg_dst_.resize(INTERBATCH_CON); 58 | node_counter_.resize(INTERBATCH_CON); 59 | edge_counter_.resize(INTERBATCH_CON); 60 | 61 | for(int i = 0; i < INTERBATCH_CON; i++){ 62 | cudaIpcOpenMemHandle(&ids_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][0], cudaIpcMemLazyEnablePeerAccess); 63 | cudaIpcOpenMemHandle(&float_features_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][1], cudaIpcMemLazyEnablePeerAccess); 64 | cudaIpcOpenMemHandle(&labels_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][2], cudaIpcMemLazyEnablePeerAccess); 65 | cudaIpcOpenMemHandle(&agg_src_[i], 
*(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][3], cudaIpcMemLazyEnablePeerAccess); 66 | cudaIpcOpenMemHandle(&agg_dst_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][4], cudaIpcMemLazyEnablePeerAccess); 67 | cudaIpcOpenMemHandle(&node_counter_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][5], cudaIpcMemLazyEnablePeerAccess); 68 | cudaIpcOpenMemHandle(&edge_counter_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][6], cudaIpcMemLazyEnablePeerAccess); 69 | cudaCheckError(); 70 | } 71 | std::cout<<"CUDA: "< ids_; 157 | std::vector float_features_; 158 | std::vector labels_; 159 | std::vector agg_src_; 160 | std::vector agg_dst_; 161 | std::vector node_counter_; 162 | std::vector edge_counter_; 163 | std::vector semw_; 164 | std::vector semr_; 165 | 166 | int32_t train_step_; 167 | int32_t valid_step_; 168 | int32_t test_step_; 169 | int current_pipe_; 170 | }; 171 | IPCEnv* NewIPCEnv(){ 172 | return new GPUIPCEnv(); 173 | } 174 | 175 | // Define the GPU implementation that launches the CUDA kernel. 176 | 177 | std::vector cuda_get_next( 178 | int32_t* ids, 179 | float* float_features, 180 | int32_t* labels, 181 | int feature_dim, 182 | int32_t* agg_src, 183 | int32_t* agg_dst, 184 | int32_t* node_counter, 185 | int32_t* edge_counter, 186 | int32_t* h_node_counter, 187 | int32_t* h_edge_counter 188 | ){ 189 | int current_dev = -1; 190 | cudaGetDevice(¤t_dev); 191 | auto device = "cuda:" + std::to_string(current_dev); 192 | cudaCheckError(); 193 | 194 | cudaMemcpy(h_node_counter, node_counter, 16 * sizeof(int32_t), cudaMemcpyDeviceToHost); 195 | cudaMemcpy(h_edge_counter, edge_counter, 16 * sizeof(int32_t), cudaMemcpyDeviceToHost); 196 | int hop_num = h_node_counter[INTRABATCH_CON * 3 - 1]; 197 | 198 | std::vector ret; 199 | 200 | torch::Tensor ids_tensor = torch::from_blob( 201 | ids, 202 | {(long long)h_node_counter[INTRABATCH_CON * 3 + hop_num]}, 203 | torch::TensorOptions().dtype(torch::kI32).device(device)); 204 | 205 | ret.push_back(ids_tensor); 206 | 207 | torch::Tensor feature_tensor = torch::from_blob( 208 | float_features, 209 | {(long long)(h_node_counter[INTRABATCH_CON * 3 + hop_num]), (long long)(feature_dim)}, 210 | torch::TensorOptions().dtype(torch::kF32).device(device)); 211 | 212 | ret.push_back(feature_tensor); 213 | 214 | torch::Tensor labels_tensor = torch::from_blob( 215 | labels, 216 | {(long long)h_node_counter[INTRABATCH_CON * 3]}, 217 | torch::TensorOptions().dtype(torch::kI32).device(device)); 218 | 219 | ret.push_back(labels_tensor); 220 | 221 | for(int i = hop_num; i > 0; i--){ 222 | torch::Tensor agg_src_tensor = torch::from_blob( 223 | agg_src, 224 | {(long long)h_edge_counter[INTRABATCH_CON * 3 + i]}, 225 | torch::TensorOptions().dtype(torch::kI32).device(device)); 226 | torch::Tensor agg_dst_tensor = torch::from_blob( 227 | agg_dst, 228 | {(long long)h_edge_counter[INTRABATCH_CON * 3 + i]}, 229 | torch::TensorOptions().dtype(torch::kI32).device(device)); 230 | ret.push_back(agg_src_tensor); 231 | ret.push_back(agg_dst_tensor); 232 | } 233 | 234 | return ret; 235 | } 236 | -------------------------------------------------------------------------------- /training_backend/ipc_service.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "ipc_service.h" 9 | 10 | 11 | IPCEnv* env; 12 | int32_t h_node_counter[16]; 13 | int32_t h_edge_counter[16]; 14 | 15 | void InitializeIPC(){ 16 | 
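    // Note: this must run after the caller has selected its CUDA device (the
    // Python trainers call torch.cuda.set_device(...) before
    // ipc_service.initialize()), because Initialize() uses cudaGetDevice() to
    // decide which per-GPU IPC handles to open from the shared-memory slab.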
env = NewIPCEnv(); 17 | env->Initialize(); 18 | } 19 | 20 | void FinalizeIPC(){ 21 | env->Finalize(); 22 | } 23 | 24 | std::vector cuda_get_next( 25 | int32_t* ids, 26 | float* float_features, 27 | int32_t* labels, 28 | int feature_dim, 29 | int32_t* agg_src, 30 | int32_t* agg_dst, 31 | int32_t* node_counter, 32 | int32_t* edge_counter, 33 | int32_t* h_node_counter, 34 | int32_t* h_edge_counter 35 | ); 36 | 37 | // C++ interface 38 | 39 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 40 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 41 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 42 | 43 | 44 | std::vector get_next(int feature_dim) { 45 | env->Wait(); 46 | int32_t* ids = env->GetIds(); 47 | float* float_features = env->GetFloatFeatures(); 48 | int32_t* labels = env->GetLabels(); 49 | int32_t* agg_src = env->GetAggSrc(); 50 | int32_t* agg_dst = env->GetAggDst(); 51 | int32_t* node_counter = env->GetNodeCounter(); 52 | int32_t* edge_counter = env->GetEdgeCounter(); 53 | auto result = cuda_get_next(ids, float_features, labels, 54 | feature_dim, 55 | agg_src, agg_dst, 56 | node_counter, edge_counter, 57 | h_node_counter, h_edge_counter); 58 | return result; 59 | } 60 | 61 | std::vector get_block_size() { 62 | std::vector ret; 63 | int hop_num = h_node_counter[INTRABATCH_CON * 3 - 1]; 64 | 65 | for(int i = hop_num; i > 0; i--){ 66 | ret.push_back(h_node_counter[INTRABATCH_CON * 3 + i]); 67 | ret.push_back(h_node_counter[INTRABATCH_CON * 3 + i - 1]); 68 | } 69 | // int block1_src_node = h_node_counter[9]; 70 | // int block1_dst_node = h_node_counter[7]; 71 | // int block2_src_node = h_node_counter[7]; 72 | // int block2_dst_node = h_node_counter[5]; 73 | 74 | // ret.push_back(block1_src_node); 75 | // ret.push_back(block1_dst_node); 76 | // ret.push_back(block2_src_node); 77 | // ret.push_back(block2_dst_node); 78 | return ret; 79 | } 80 | 81 | std::vector get_steps(){ 82 | std::vector ret; 83 | ret.push_back(env->GetTrainStep()); 84 | ret.push_back(env->GetValidStep()); 85 | ret.push_back(env->GetTestStep()); 86 | return ret; 87 | } 88 | 89 | void Synchronize(){ 90 | env->Post(); 91 | } 92 | 93 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 94 | m.def("get_next", &get_next, "dataset get next (CUDA)"); 95 | m.def("get_block_size", &get_block_size, "get dgl block size(CUDA)"); 96 | m.def("get_steps", &get_steps, "get steps(CUDA)"); 97 | m.def("initialize", &InitializeIPC, "InitializeIPC (CUDA)"); 98 | m.def("finalize", &FinalizeIPC, "FinalizeIPC (CUDA)"); 99 | m.def("synchronize", &Synchronize, "synchronize (CUDA)"); 100 | } -------------------------------------------------------------------------------- /training_backend/ipc_service.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define INTRABATCH_CON 3 4 | #define INTERBATCH_CON 2 5 | 6 | class IPCEnv { 7 | public: 8 | virtual int Initialize() = 0; 9 | virtual int32_t* GetIds() = 0; 10 | virtual float* GetFloatFeatures() = 0; 11 | virtual int32_t* GetLabels() = 0; 12 | virtual int32_t* GetAggSrc() = 0; 13 | virtual int32_t* GetAggDst() = 0; 14 | virtual int32_t* GetNodeCounter() = 0; 15 | virtual int32_t* GetEdgeCounter() = 0; 16 | virtual int32_t GetTrainStep() = 0; 17 | virtual int32_t GetValidStep() = 0; 18 | virtual int32_t GetTestStep() = 0; 19 | virtual void Finalize() = 0; 20 | virtual void Wait() = 0; 21 | virtual void Post() = 0; 22 | }; 23 | IPCEnv* NewIPCEnv(); 
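// Consumer-side usage sketch (illustrative; this mirrors get_next() and
// Synchronize() in ipc_service.cpp rather than adding new API):
//
//   IPCEnv* env = NewIPCEnv();
//   env->Initialize();                  // open the sampler's shared buffers
//   for (int32_t step = 0; step < env->GetTrainStep(); ++step) {
//     env->Wait();                      // block until a batch is published
//     int32_t* ids = env->GetIds();     // device pointers, valid until Post()
//     float* feats = env->GetFloatFeatures();
//     /* build blocks from GetAggSrc()/GetAggDst() and train on the batch */
//     env->Post();                      // hand the buffers back to the sampler
//   }
//   env->Finalize();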
-------------------------------------------------------------------------------- /training_backend/legion_gat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0" 4 | import sys 5 | import tempfile 6 | import argparse 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torch.multiprocessing as mp 12 | 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | import torch.nn.functional as Func 15 | 16 | import ipc_service 17 | import dgl 18 | from dgl.nn.pytorch import GATConv 19 | from dgl.heterograph import DGLBlock 20 | import time 21 | import numpy as np 22 | import torchmetrics 23 | torch.set_printoptions(threshold=np.inf) 24 | 25 | def setup(rank, world_size): 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | os.environ['MASTER_PORT'] = '12355' 28 | # initialize the process group 29 | if torch.cuda.is_available(): 30 | dist.init_process_group('nccl', rank=rank, world_size=world_size) 31 | else: 32 | dist.init_process_group('gloo', rank=rank, world_size=world_size) 33 | 34 | def cleanup(): 35 | dist.destroy_process_group() 36 | 37 | class GAT(nn.Module): 38 | def __init__(self, in_feats, n_hidden, n_classes, heads, n_layers, activation, dropout): 39 | super().__init__() 40 | self.n_layers = n_layers 41 | self.n_hidden = n_hidden 42 | self.n_classes = n_classes 43 | self.gat_layers = nn.ModuleList() 44 | self.gat_layers.append( 45 | GATConv( 46 | in_feats, 47 | n_hidden, 48 | heads[0], 49 | feat_drop=0.6, 50 | attn_drop=0.6, 51 | activation=activation, 52 | allow_zero_in_degree=True, 53 | ) 54 | ) 55 | self.gat_layers.append( 56 | GATConv( 57 | n_hidden * heads[0], 58 | n_classes, 59 | heads[1], 60 | feat_drop=0.6, 61 | attn_drop=0.6, 62 | allow_zero_in_degree=True, 63 | ) 64 | ) 65 | self.dropout = nn.Dropout(dropout) 66 | self.activation = activation 67 | 68 | def forward(self, blocks, x): 69 | h = x 70 | for l, (layer, block) in enumerate(zip(self.gat_layers, blocks)): 71 | h = layer(block, h) 72 | if l != len(self.gat_layers) - 1: 73 | h = self.activation(h) 74 | h = self.dropout(h) 75 | if l == 1: # last layer 76 | h = h.mean(1) 77 | else: # other layer(s) 78 | h = h.flatten(1) 79 | return h 80 | 81 | def create_dgl_block(src, dst, num_src_nodes, num_dst_nodes): 82 | gidx = dgl.heterograph_index.create_unitgraph_from_coo(2, num_src_nodes, num_dst_nodes, src, dst, 'coo', row_sorted=True) 83 | g = DGLBlock(gidx, (['_N'], ['_N']), ['_E']) 84 | 85 | return g 86 | 87 | def train_one_step(model, optimizer, loss_fcn, device, feat_len, iter, device_id): 88 | 89 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 90 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 91 | 92 | blocks = [] 93 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 94 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 95 | # print(features[:100]) 96 | # print(ids[:100]) 97 | batch_pred = model(blocks, features) 98 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 99 | loss = loss_fcn(batch_pred, long_labels) 100 | optimizer.zero_grad() 101 | loss.backward() 102 | optimizer.step() 103 | 104 | torch.cuda.synchronize() 105 | ipc_service.synchronize() 106 | return 0 107 | 108 | def valid_one_step(model, metric, 
device, feat_len): 109 | 110 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 111 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 112 | blocks = [] 113 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 114 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 115 | batch_pred = model(blocks, features) 116 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 117 | batch_pred = torch.softmax(batch_pred, dim=1).to(device) 118 | acc = metric(batch_pred, long_labels) 119 | ipc_service.synchronize() 120 | return acc 121 | 122 | def test_one_step(model, metric, device, feat_len): 123 | 124 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 125 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 126 | blocks = [] 127 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 128 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 129 | batch_pred = model(blocks, features) 130 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 131 | batch_pred = torch.softmax(batch_pred, dim=1).to(device) 132 | acc = metric(batch_pred, long_labels) 133 | ipc_service.synchronize() 134 | return acc 135 | 136 | def worker_process(rank, world_size, args): 137 | print(f"Running GNN Training on CUDA {rank}.") 138 | device_id = rank 139 | setup(rank, world_size) 140 | cuda_device = torch.device("cuda:{}".format(device_id)) 141 | torch.cuda.set_device(cuda_device) 142 | ipc_service.initialize() 143 | train_steps, valid_steps, test_steps = ipc_service.get_steps() 144 | batch_size = (args.train_batch_size) 145 | hop1 = (args.nbrs_num)[0] 146 | hop2 = (args.nbrs_num)[1] 147 | 148 | feat_len = args.features_num 149 | 150 | heads = [8,1] 151 | model = GAT(in_feats=args.features_num, 152 | n_hidden=args.hidden_dim, 153 | n_classes=args.class_num, 154 | heads=[8,1], 155 | n_layers=args.hops_num, 156 | activation=Func.relu, 157 | dropout=args.drop_rate).to(cuda_device) 158 | 159 | if dist.is_initialized(): 160 | model = DDP(model, device_ids=[device_id]) 161 | loss_fcn = nn.CrossEntropyLoss() 162 | loss_fcn = loss_fcn.to(device_id) 163 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 164 | model.train() 165 | 166 | epoch_num = args.epoch 167 | 168 | for epoch in range(epoch_num): 169 | forward = 0 170 | start = time.time() 171 | epoch_time = 0 172 | for iter in range(train_steps): 173 | train_loss = train_one_step(model, optimizer, loss_fcn, cuda_device, feat_len, iter, device_id) 174 | # if device_id == 0: 175 | # print('Iter {} Train Loss :{} '.format(iter, train_loss)) 176 | epoch_time += time.time() - start 177 | 178 | model.eval() 179 | metric = torchmetrics.Accuracy('multiclass', num_classes = args.class_num) 180 | metric = metric.to(device_id) 181 | model.metric = metric 182 | with torch.no_grad(): 183 | for iter in range(valid_steps): 184 | valid_one_step(model, metric, cuda_device, feat_len) 185 | acc_val = metric.compute() 186 | if device_id == 0: 187 | print("Epoch:{}, Cost:{} s, Val Acc: {}".format(epoch, epoch_time, acc_val)) 188 | 189 | 190 | model.eval() 191 | metric = torchmetrics.Accuracy('multiclass', num_classes = args.class_num) 192 | metric 
= metric.to(device_id)
193 |     model.metric = metric
194 |     with torch.no_grad():
195 |         for iter in range(test_steps):
196 |             test_one_step(model, metric, cuda_device, feat_len)
197 |     acc = metric.compute()
198 |     if device_id == 0:
199 |         print("Accuracy on test data: {}".format(acc))
200 |     metric.reset()
201 | 
202 |     ipc_service.finalize()
203 |     cleanup()
204 | 
205 | def run_distribute(dist_fn, world_size, args):
206 |     mp.spawn(dist_fn,
207 |              args=(world_size, args),
208 |              nprocs=world_size,
209 |              join=True)
210 | 
211 | if __name__ == "__main__":
212 |     cur_path = sys.path[0]
213 |     argparser = argparse.ArgumentParser("Train GNN.")
214 |     argparser.add_argument('--class_num', type=int, default=2)
215 |     argparser.add_argument('--features_num', type=int, default=128)
216 |     argparser.add_argument('--hidden_dim', type=int, default=256)
217 |     argparser.add_argument('--hops_num', type=int, default=2)
218 |     argparser.add_argument('--nbrs_num', type=list, default=[25, 10])
219 |     argparser.add_argument('--train_batch_size', type=int, default=8000)  # read by worker_process; matches legion_server.py
220 |     argparser.add_argument('--drop_rate', type=float, default=0.5)
221 |     argparser.add_argument('--learning_rate', type=float, default=0.003)
222 |     argparser.add_argument('--epoch', type=int, default=2)
223 |     argparser.add_argument('--gpu_number', type=int, default=2)
224 |     args = argparser.parse_args()
225 | 
226 |     world_size = args.gpu_number
227 | 
228 |     run_distribute(worker_process, world_size, args)
229 | 
--------------------------------------------------------------------------------
/training_backend/legion_gcn.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
4 | import sys
5 | import tempfile
6 | import argparse
7 | import torch
8 | import torch.distributed as dist
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.multiprocessing as mp
12 | 
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | import torch.nn.functional as Func
15 | 
16 | import ipc_service
17 | import dgl
18 | from dgl.nn.pytorch import SAGEConv
19 | from dgl.nn.pytorch import GraphConv
20 | from dgl.heterograph import DGLBlock
21 | import time
22 | import numpy as np
23 | import torchmetrics
24 | torch.set_printoptions(threshold=np.inf)
25 | 
26 | def setup(rank, world_size):
27 |     os.environ['MASTER_ADDR'] = 'localhost'
28 |     os.environ['MASTER_PORT'] = '12355'
29 |     # initialize the process group
30 |     if torch.cuda.is_available():
31 |         dist.init_process_group('nccl', rank=rank, world_size=world_size)
32 |     else:
33 |         dist.init_process_group('gloo', rank=rank, world_size=world_size)
34 | 
35 | def cleanup():
36 |     dist.destroy_process_group()
37 | 
38 | class SAGE(nn.Module):
39 |     def __init__(self,
40 |                  in_feats,
41 |                  n_hidden,
42 |                  n_classes,
43 |                  n_layers,
44 |                  activation,
45 |                  dropout):
46 |         super().__init__()
47 |         self.n_layers = n_layers
48 |         self.n_hidden = n_hidden
49 |         self.n_classes = n_classes
50 |         self.layers = nn.ModuleList()
51 |         self.layers.append(SAGEConv(in_feats, n_hidden, 'mean'))
52 |         for _ in range(1, n_layers - 1):
53 |             self.layers.append(SAGEConv(n_hidden, n_hidden, 'mean'))
54 |         self.layers.append(SAGEConv(n_hidden, n_classes, 'mean'))
55 |         self.dropout = nn.Dropout(dropout)
56 |         self.activation = activation
57 | 
58 |     def forward(self, blocks, x):
59 |         h = x
60 |         for l, (layer, block) in enumerate(zip(self.layers, blocks)):
61 |             h = layer(block, h)
62 |             if l != len(self.layers) - 1:
63 |                 h = self.activation(h)
64 |                 h = self.dropout(h)
65 |         return h
66 | 
67 | 
68 | class GCN(nn.Module):
69 |     def
__init__(self, 70 | in_feats, 71 | n_hidden, 72 | n_classes, 73 | n_layers, 74 | activation, 75 | dropout): 76 | super(GCN, self).__init__() 77 | self.layers = nn.ModuleList() 78 | # input layer 79 | self.layers.append( 80 | GraphConv(in_feats, n_hidden, activation=activation, allow_zero_in_degree=True)) 81 | # hidden layers 82 | for _ in range(n_layers - 2): 83 | self.layers.append( 84 | GraphConv(n_hidden, n_hidden, activation=activation, allow_zero_in_degree=True)) 85 | # output layer 86 | self.layers.append( 87 | GraphConv(n_hidden, n_classes, allow_zero_in_degree=True)) 88 | self.dropout = nn.Dropout(p=dropout) 89 | 90 | def forward(self, blocks, features): 91 | h = features 92 | for i, layer in enumerate(self.layers): 93 | if i != 0: 94 | h = self.dropout(h) 95 | h = layer(blocks[i], h) 96 | return h 97 | 98 | def create_dgl_block(src, dst, num_src_nodes, num_dst_nodes): 99 | gidx = dgl.heterograph_index.create_unitgraph_from_coo(2, num_src_nodes, num_dst_nodes, src, dst, 'coo', row_sorted=True) 100 | g = DGLBlock(gidx, (['_N'], ['_N']), ['_E']) 101 | 102 | return g 103 | 104 | def train_one_step(model, optimizer, loss_fcn, device, feat_len, iter, device_id): 105 | 106 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 107 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 108 | blocks = [] 109 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 110 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 111 | 112 | batch_pred = model(blocks, features) 113 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 114 | loss = loss_fcn(batch_pred, long_labels) 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | torch.cuda.synchronize() 119 | ipc_service.synchronize() 120 | return loss 121 | 122 | def valid_one_step(model, metric, device, feat_len): 123 | 124 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 125 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 126 | blocks = [] 127 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 128 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 129 | batch_pred = model(blocks, features) 130 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 131 | 132 | batch_pred = torch.softmax(batch_pred, dim=1).to(device) 133 | acc = metric(batch_pred, long_labels) 134 | ipc_service.synchronize() 135 | return acc 136 | 137 | def test_one_step(model, metric, device, feat_len): 138 | 139 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 140 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 141 | blocks = [] 142 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 143 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 144 | batch_pred = model(blocks, features) 145 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 146 | batch_pred = torch.softmax(batch_pred, dim=1).to(device) 147 | acc = metric(batch_pred, long_labels) 148 | ipc_service.synchronize() 149 | return 
acc 150 | 151 | def worker_process(rank, world_size, args): 152 | print(f"Running GNN Training on CUDA {rank}.") 153 | device_id = rank 154 | setup(rank, world_size) 155 | cuda_device = torch.device("cuda:{}".format(device_id)) 156 | torch.cuda.set_device(cuda_device) 157 | ipc_service.initialize() 158 | train_steps, valid_steps, test_steps = ipc_service.get_steps() 159 | 160 | feat_len = args.features_num 161 | 162 | model = GCN(in_feats=args.features_num, 163 | n_hidden=args.hidden_dim, 164 | n_classes=args.class_num, 165 | n_layers=args.hops_num, 166 | activation=Func.relu, 167 | dropout=args.drop_rate).to(cuda_device) 168 | 169 | if dist.is_initialized(): 170 | model = DDP(model, device_ids=[device_id]) 171 | loss_fcn = nn.CrossEntropyLoss() 172 | loss_fcn = loss_fcn.to(device_id) 173 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 174 | model.train() 175 | 176 | epoch_num = args.epoch 177 | 178 | for epoch in range(epoch_num): 179 | forward = 0 180 | start = time.time() 181 | epoch_time = 0 182 | for iter in range(train_steps): 183 | train_loss = train_one_step(model, optimizer, loss_fcn, cuda_device, feat_len, iter, device_id) 184 | # if device_id == 0: 185 | # print('Iter {} Train Loss :{} '.format(iter, train_loss)) 186 | epoch_time += time.time() - start 187 | 188 | model.eval() 189 | metric = torchmetrics.Accuracy('multiclass', num_classes = args.class_num) 190 | metric = metric.to(device_id) 191 | model.metric = metric 192 | with torch.no_grad(): 193 | for iter in range(valid_steps): 194 | valid_one_step(model, metric, cuda_device, feat_len) 195 | acc_val = metric.compute() 196 | if device_id == 0: 197 | print("Epoch:{}, Cost:{} s, Val Acc: {}".format(epoch, epoch_time, acc_val)) 198 | 199 | 200 | model.eval() 201 | metric = torchmetrics.Accuracy('multiclass', num_classes = args.class_num) 202 | metric = metric.to(device_id) 203 | model.metric = metric 204 | with torch.no_grad(): 205 | for iter in range(test_steps): 206 | test_one_step(model, metric, cuda_device, feat_len) 207 | acc = metric.compute() 208 | if device_id == 0: 209 | print("Accuracy on test data: {}".format(acc)) 210 | metric.reset() 211 | 212 | ipc_service.finalize() 213 | cleanup() 214 | 215 | def run_distribute(dist_fn, world_size, args): 216 | mp.spawn(dist_fn, 217 | args=(world_size, args), 218 | nprocs=world_size, 219 | join=True) 220 | 221 | if __name__ == "__main__": 222 | cur_path = sys.path[0] 223 | argparser = argparse.ArgumentParser("Train GNN.") 224 | argparser.add_argument('--class_num', type=int, default=2) 225 | argparser.add_argument('--features_num', type=int, default=128) 226 | argparser.add_argument('--hidden_dim', type=int, default=256) 227 | argparser.add_argument('--hops_num', type=int, default=2) 228 | argparser.add_argument('--nbrs_num', type=list, default=[25, 10]) 229 | argparser.add_argument('--drop_rate', type=float, default=0.5) 230 | argparser.add_argument('--learning_rate', type=float, default=0.003) 231 | argparser.add_argument('--epoch', type=int, default=2) 232 | argparser.add_argument('--gpu_number', type=int, default=2) 233 | args = argparser.parse_args() 234 | 235 | world_size = args.gpu_number 236 | 237 | run_distribute(worker_process, world_size, args) 238 | -------------------------------------------------------------------------------- /training_backend/legion_graphsage.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0" 4 | import sys 5 | import 
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import sys
import tempfile
import argparse
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as Func

import ipc_service
import dgl
from dgl.nn.pytorch import SAGEConv
from dgl.heterograph import DGLBlock
import time
import numpy as np
import torchmetrics
torch.set_printoptions(threshold=np.inf)

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    if torch.cuda.is_available():
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
    else:
        dist.init_process_group('gloo', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class SAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(SAGEConv(in_feats, n_hidden, 'mean'))
        for _ in range(1, n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

def create_dgl_block(src, dst, num_src_nodes, num_dst_nodes):
    # Wrap a sampled COO adjacency (src -> dst) in a DGL message-flow block.
    gidx = dgl.heterograph_index.create_unitgraph_from_coo(2, num_src_nodes, num_dst_nodes, src, dst, 'coo', row_sorted=True)
    g = DGLBlock(gidx, (['_N'], ['_N']), ['_E'])
    return g

def train_one_step(model, optimizer, loss_fcn, device, feat_len, iter, device_id):
    # Fetch one pre-sampled mini-batch from the Legion server via shared memory.
    ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len)
    block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size()

    blocks = []
    blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num))
    blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num))
    # print(features[:100])
    # print(ids[:100])
    batch_pred = model(blocks, features)
    long_labels = torch.as_tensor(labels, dtype=torch.long, device=device)
    loss = loss_fcn(batch_pred, long_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    torch.cuda.synchronize()
    ipc_service.synchronize()
    return loss

def valid_one_step(model, metric, device, feat_len):
    ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len)
    block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size()
    blocks = []
    blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num))
    blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num))
    batch_pred = model(blocks, features)
    long_labels = torch.as_tensor(labels, dtype=torch.long, device=device)
    batch_pred = torch.softmax(batch_pred, dim=1).to(device)
    acc = metric(batch_pred, long_labels)
    ipc_service.synchronize()
    return acc
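# Both valid_one_step() above and test_one_step() below feed one batch into a
# shared torchmetrics object; the metric accumulates state across calls, and
# the caller reads the aggregate accuracy with metric.compute(). The softmax
# is not required for accuracy (it preserves the argmax) but keeps the
# predictions in probability form.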
def test_one_step(model, metric, device, feat_len):
    ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len)
    block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size()
    blocks = []
    blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num))
    blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num))
    batch_pred = model(blocks, features)
    long_labels = torch.as_tensor(labels, dtype=torch.long, device=device)
    batch_pred = torch.softmax(batch_pred, dim=1).to(device)
    acc = metric(batch_pred, long_labels)
    ipc_service.synchronize()
    return acc

def worker_process(rank, world_size, args):
    print(f"Running GNN Training on CUDA {rank}.")
    device_id = rank
    setup(rank, world_size)
    cuda_device = torch.device("cuda:{}".format(device_id))
    torch.cuda.set_device(cuda_device)
    ipc_service.initialize()
    train_steps, valid_steps, test_steps = ipc_service.get_steps()

    feat_len = args.features_num

    model = SAGE(in_feats=args.features_num,
                 n_hidden=args.hidden_dim,
                 n_classes=args.class_num,
                 n_layers=args.hops_num,
                 activation=Func.relu,
                 dropout=args.drop_rate).to(cuda_device)

    if dist.is_initialized():
        model = DDP(model, device_ids=[device_id])
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device_id)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    epoch_num = args.epoch

    for epoch in range(epoch_num):
        model.train()  # re-enable dropout after the previous epoch's eval pass
        start = time.time()
        for iter in range(train_steps):
            train_loss = train_one_step(model, optimizer, loss_fcn, cuda_device, feat_len, iter, device_id)
            # if device_id == 0:
            #     print('Iter {} Train Loss :{} '.format(iter, train_loss))
        epoch_time = time.time() - start

        model.eval()
        metric = torchmetrics.Accuracy('multiclass', num_classes=args.class_num)
        metric = metric.to(device_id)
        with torch.no_grad():
            for iter in range(valid_steps):
                valid_one_step(model, metric, cuda_device, feat_len)
            acc_val = metric.compute()
            if device_id == 0:
                print("Epoch:{}, Cost:{} s, Val Acc: {}".format(epoch, epoch_time, acc_val))

    model.eval()
    metric = torchmetrics.Accuracy('multiclass', num_classes=args.class_num)
    metric = metric.to(device_id)
    with torch.no_grad():
        for iter in range(test_steps):
            test_one_step(model, metric, cuda_device, feat_len)
        acc = metric.compute()
        if device_id == 0:
            print("Accuracy on test data: {}".format(acc))
        metric.reset()

    ipc_service.finalize()
    cleanup()

def run_distribute(dist_fn, world_size, args):
    # Launch one training process per GPU.
    mp.spawn(dist_fn,
             args=(world_size, args),
             nprocs=world_size,
             join=True)
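# NOTE (assumption, not enforced anywhere in this script): the tensor shapes
# delivered by ipc_service are fixed by the sampling server, so --features_num,
# --hops_num, and --gpu_number presumably have to agree with the flags the
# Legion server was started with. --nbrs_num is parsed but unused here; the
# fanout is applied on the server side during sampling.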
if __name__ == "__main__":
    cur_path = sys.path[0]
    argparser = argparse.ArgumentParser("Train GNN.")
    argparser.add_argument('--class_num', type=int, default=2)
    argparser.add_argument('--features_num', type=int, default=128)
    argparser.add_argument('--hidden_dim', type=int, default=256)
    argparser.add_argument('--hops_num', type=int, default=2)
    # accepts e.g. "--nbrs_num 25 10"
    argparser.add_argument('--nbrs_num', type=int, nargs='+', default=[25, 10])
    argparser.add_argument('--drop_rate', type=float, default=0.5)
    argparser.add_argument('--learning_rate', type=float, default=0.003)
    argparser.add_argument('--epoch', type=int, default=2)
    argparser.add_argument('--gpu_number', type=int, default=2)
    args = argparser.parse_args()

    world_size = args.gpu_number

    run_distribute(worker_process, world_size, args)
--------------------------------------------------------------------------------
/training_backend/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os

# Point the build at the local CUDA toolkit; adjust if CUDA lives elsewhere.
os.environ['CUDA_HOME'] = '/usr/local/cuda'

setup(
    name='ipcservice',
    ext_modules=[
        CUDAExtension('ipc_service', [
            'ipc_service.cpp',
            'helper_multiprocess.cpp',
            'ipc_cuda_kernel.cu',
        ])
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
--------------------------------------------------------------------------------