├── .gitignore
├── Hyperion_server.py
├── IOStack
│   ├── Makefile
│   ├── common.cuh
│   ├── ioctl.h
│   ├── iostack.cuh
│   ├── map.h
│   ├── ssd_write.cu
│   ├── ssdqp.cuh
│   └── test.cu
├── README.md
├── build.sh
├── dataset
│   ├── README.md
│   ├── convert_to_bin.py
│   ├── gen_legion_xtrapulp_fomat.cpp
│   ├── gen_sets.py
│   ├── webgraph-3.5.2.jar
│   └── webgraph-3.6.8-deps.tar.gz
├── meta_config
├── prepare_dataset.sh
├── sampling_server
│   ├── Makefile
│   ├── hyperion.cpp
│   └── src
│       ├── cache
│       │   ├── cache.cu
│       │   ├── cache.cuh
│       │   └── cache_impl.cuh
│       ├── engine
│       │   ├── helper_multiprocess.cu
│       │   ├── helper_multiprocess.h
│       │   ├── ipc_service.cu
│       │   ├── ipc_service.h
│       │   ├── memorypool.cu
│       │   ├── memorypool.cuh
│       │   ├── monitor.cuh
│       │   ├── operator.cu
│       │   ├── operator.h
│       │   ├── operator_impl.cu
│       │   ├── operator_impl.cuh
│       │   ├── server.cu
│       │   ├── server.h
│       │   └── server_imp.cuh
│       ├── include
│       │   ├── buildinfo.h
│       │   ├── hashmap.h
│       │   ├── hashmap
│       │   │   ├── CMakeLists.txt
│       │   │   ├── bcht.hpp
│       │   │   ├── benchmark_helpers.cuh
│       │   │   ├── cht.hpp
│       │   │   ├── cmd.hpp
│       │   │   ├── detail
│       │   │   │   ├── allocator.hpp
│       │   │   │   ├── bcht_impl.cuh
│       │   │   │   ├── benchmark_metrics.cuh
│       │   │   │   ├── bucket.cuh
│       │   │   │   ├── cht_impl.cuh
│       │   │   │   ├── cuda_helpers.cuh
│       │   │   │   ├── hash_functions.cuh
│       │   │   │   ├── iht_impl.cuh
│       │   │   │   ├── kernels.cuh
│       │   │   │   ├── p2bht_impl.cuh
│       │   │   │   ├── pair.cuh
│       │   │   │   ├── pair_detail.hpp
│       │   │   │   ├── ptx.cuh
│       │   │   │   └── rng.hpp
│       │   │   ├── genzipf.hpp
│       │   │   ├── gpu_timer.hpp
│       │   │   ├── iht.hpp
│       │   │   ├── p2bht.hpp
│       │   │   ├── perf_report.hpp
│       │   │   └── rkg.hpp
│       │   └── system_config.cuh
│       └── storage
│           ├── feature_storage.cu
│           ├── feature_storage.cuh
│           ├── feature_storage_impl.cuh
│           ├── graph_storage.cu
│           ├── graph_storage.cuh
│           ├── graph_storage_impl.cuh
│           ├── ioctl.h
│           ├── iostack.cuh
│           ├── map.h
│           ├── ssdqp.cuh
│           ├── storage_management.cu
│           ├── storage_management.cuh
│           ├── storage_management_impl.cuh
│           └── userqueue.cuh
├── training_backend
│   ├── helper_multiprocess.cpp
│   ├── helper_multiprocess.h
│   ├── hyperion_gat.py
│   ├── hyperion_gat3hop.py
│   ├── hyperion_gcn.py
│   ├── hyperion_gcn3hop.py
│   ├── hyperion_graphsage.py
│   ├── hyperion_graphsage3hop.py
│   ├── ipc_cuda_kernel.cu
│   ├── ipc_service.cpp
│   ├── ipc_service.h
│   └── setup.py
└── unload_ssd.py
/.gitignore: -------------------------------------------------------------------------------- 1 | dataset/ukunion 2 | dataset/lib 3 | dataset/xtrapulp 4 | dataset/xtrapulp_result 5 | dataset/gen_legion_xtrapulp_fomat 6 | IOStack/test 7 | sampling_server/build 8 | training_backend/build 9 | training_backend/dist 10 | training_backend/ipcservice.egg-info 11 | meta_config -------------------------------------------------------------------------------- /Hyperion_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | import re 5 | import networkx as nx 6 | import math 7 | import sys 8 | 9 | sys.path.append("sampling_server/build") 10 | import hyperion 11 | 12 | def Run(args): 13 | 14 | if args.dataset_name == "products": 15 | path = args.dataset_path + "/products/" 16 | vertices_num = 2449029 17 | edges_num = 123718280 18 | features_dim = 100 19 | train_set_num = 196615 20 | valid_set_num = 39323 21 | test_set_num = 2213091 22 | elif args.dataset_name == "paper100m": 23 | path = args.dataset_path + "/paper100M/" 24 | vertices_num = 111059956 25 | edges_num = 1615685872 26 | features_dim = 128 27 | train_set_num = 11105995 28 | valid_set_num = 100000 29 | test_set_num = 100000 30 | elif args.dataset_name == "com-friendster": 31 | path = args.dataset_path + "/com-friendster/" 32 | vertices_num = 65608366 33 | edges_num = 1806067135 34 
| features_dim = 256 35 | train_set_num = 6560836 36 | valid_set_num = 100000 37 | test_set_num = 100000 38 | elif args.dataset_name == "ukunion": 39 | path = args.dataset_path + "/ukunion/" 40 | vertices_num = 133633040 41 | edges_num = 5507679822 42 | features_dim = 256 43 | train_set_num = 13363304 44 | valid_set_num = 100000 45 | test_set_num = 100000 46 | elif args.dataset_name == "uk2014": 47 | path = args.dataset_path + "/uk2014/" 48 | vertices_num = 787801471 49 | edges_num = 47284178505 50 | features_dim = 128 51 | train_set_num = 78780147 52 | valid_set_num = 100000 53 | test_set_num = 100000 54 | elif args.dataset_name == "clueweb": 55 | path = args.dataset_path + "/clueweb/" 56 | vertices_num = 955207488 57 | edges_num = 42574107469 58 | features_dim = 128 59 | train_set_num = 95520748 60 | valid_set_num = 100000 61 | test_set_num = 100000 62 | elif args.dataset_name == "igb": 63 | path = args.dataset_path + "/igb/" 64 | vertices_num = 269346175 65 | edges_num = 3870892894 66 | features_dim = 256 67 | train_set_num = 2693461 68 | valid_set_num = 165 69 | test_set_num = 218 70 | else: 71 | print("invalid dataset name") 72 | sys.exit(1) 73 | 74 | 75 | with open("meta_config","w") as file: 76 | file.write("{} {} {} {} {} {} {} {} {} {} {} {} {} {} {} {} ".format(path, args.train_batch_size, vertices_num, edges_num, \ 77 | features_dim, train_set_num, valid_set_num, test_set_num, \ 78 | args.epoch, 0, args.ssd_number, args.num_queues_per_ssd, \ 79 | args.CPU_Topo_memory, args.GPU_Topo_memory, args.CPU_Feat_memory, args.GPU_Feat_memory)) 80 | 81 | server = hyperion.NewGPUServer() 82 | server.initialize(gpu_number=1, fanout=args.fanout) ## configure fanouts 83 | server.presc(0) 84 | server.run() 85 | server.finalize() 86 | 87 | if __name__ == "__main__": 88 | 89 | argparser = argparse.ArgumentParser("Legion Server.") 90 | argparser.add_argument('--dataset_path', type=str, default="/share/gnn_data") 91 | argparser.add_argument('--dataset_name', type=str, default="igb") 92 | argparser.add_argument('--train_batch_size', type=int, default=8000) 93 | argparser.add_argument('--fanout', type=int, nargs='+', default=[25, 10]) 94 | argparser.add_argument('--gpu_number', type=int, default=1) 95 | argparser.add_argument('--epoch', type=int, default=2) 96 | argparser.add_argument('--ssd_number', type=int, default=2) 97 | argparser.add_argument('--num_queues_per_ssd', type=int, default=128) 98 | argparser.add_argument('--CPU_Topo_memory', type=int, default=3300000000) 99 | argparser.add_argument('--GPU_Topo_memory', type=int, default=20000) 100 | argparser.add_argument('--CPU_Feat_memory', type=int, default=600000000) 101 | argparser.add_argument('--GPU_Feat_memory', type=int, default=2000000000) 102 | 103 | args = argparser.parse_args() 104 | 105 | Run(args) 106 | -------------------------------------------------------------------------------- /IOStack/Makefile: -------------------------------------------------------------------------------- 1 | NVCC = nvcc 2 | 3 | all: test 4 | 5 | test: test.cu common.cuh iostack.cuh ssdqp.cuh 6 | $(NVCC) $(CFLAGS) -o $@ $< -g 7 | 8 | clean: 9 | rm -f test decouple 10 | -------------------------------------------------------------------------------- /IOStack/common.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <cstdio> 3 | #define REG_SIZE 0x4000 // BAR 0 mapped size 4 | #define REG_CC 0x14 // addr: controller configuration 5 | #define REG_CC_EN 0x1 // mask: enable controller 6 | #define REG_CSTS 0x1c // addr: 
controller status 7 | #define REG_CSTS_RDY 0x1 // mask: controller ready 8 | #define REG_AQA 0x24 // addr: admin queue attributes 9 | #define REG_ASQ 0x28 // addr: admin submission queue base addr 10 | #define REG_ACQ 0x30 // addr: admin completion queue base addr 11 | #define REG_SQTDBL 0x1000 // addr: submission queue 0 tail doorbell 12 | #define REG_CQHDBL 0x1004 // addr: completion queue 0 head doorbell 13 | #define DBL_STRIDE 8 14 | #define PHASE_MASK 0x10000 // mask: phase tag 15 | #define HOST_PGSZ 0x1000 16 | #define DEVICE_PGSZ 0x10000 17 | #define CID_MASK 0xffff // mask: command id 18 | #define SC_MASK 0xff // mask: status code 19 | #define BROADCAST_NSID 0 // broadcast namespace id, 980pro set to 0 20 | #define OPCODE_SET_FEATURES 0x09 21 | #define OPCODE_CREATE_IO_CQ 0x05 22 | #define OPCODE_CREATE_IO_SQ 0x01 23 | #define OPCODE_READ 0x02 24 | #define OPCODE_WRITE 0x01 25 | #define FID_NUM_QUEUES 0x07 26 | #define LB_SIZE 0x200 27 | #define RW_RETRY_MASK 0x80000000 28 | #define SQ_ITEM_SIZE 64 29 | #define WARP_SIZE 32 30 | #define SQ_HEAD_MASK 0xffff 31 | 32 | #define MAX_IO_SIZE 4096 33 | #define LBS 512 34 | #define MAX_ITEMS (MAX_IO_SIZE / LBS) 35 | #define NUM_THREADS_PER_BLOCK 512 36 | #define ADMIN_QUEUE_DEPTH 64 37 | #define QUEUE_DEPTH 4096 38 | #define QUEUE_IOBUF_SIZE (MAX_IO_SIZE * QUEUE_DEPTH) 39 | #define NUM_PRP_ENTRIES (MAX_IO_SIZE / HOST_PGSZ) 40 | #define PRP_SIZE (NUM_PRP_ENTRIES * sizeof(uint64_t)) 41 | #define NUM_LBS_PER_SSD 0x746a5288//0x60000000 //0x66666666 // 42 | #define MAX_SSDS_SUPPORTED 16 43 | 44 | #define CHECK(ans) gpuAssert((ans), __FILE__, __LINE__) 45 | 46 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) 47 | { 48 | if (code != cudaSuccess) 49 | { 50 | fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); 51 | if (abort) 52 | exit(1); 53 | } 54 | } 55 | 56 | __device__ void gpusleep(uint64_t cycles) 57 | { 58 | uint64_t start = clock64(); 59 | while (clock64() - start < cycles) 60 | ; 61 | } 62 | -------------------------------------------------------------------------------- /IOStack/ioctl.h: -------------------------------------------------------------------------------- 1 | #ifndef __NVM_INTERNAL_LINUX_IOCTL_H__ 2 | #define __NVM_INTERNAL_LINUX_IOCTL_H__ 3 | #ifdef __linux__ 4 | 5 | #include <stdint.h> 6 | #include <linux/ioctl.h> 7 | 8 | #define NVM_IOCTL_TYPE 0x80 9 | 10 | 11 | 12 | /* Memory map request */ 13 | struct nvm_ioctl_map 14 | { 15 | uint64_t vaddr_start; 16 | size_t n_pages; 17 | uint64_t* ioaddrs; 18 | }; 19 | 20 | 21 | 22 | /* Supported operations */ 23 | enum nvm_ioctl_type 24 | { 25 | NVM_MAP_HOST_MEMORY = _IOW(NVM_IOCTL_TYPE, 1, struct nvm_ioctl_map), 26 | #ifdef _CUDA 27 | NVM_MAP_DEVICE_MEMORY = _IOW(NVM_IOCTL_TYPE, 2, struct nvm_ioctl_map), 28 | #endif 29 | NVM_UNMAP_MEMORY = _IOW(NVM_IOCTL_TYPE, 3, uint64_t) 30 | }; 31 | 32 | 33 | #endif /* __linux__ */ 34 | #endif /* __NVM_INTERNAL_LINUX_IOCTL_H__ */ 35 | -------------------------------------------------------------------------------- /IOStack/map.h: -------------------------------------------------------------------------------- 1 | #ifndef __NVM_INTERNAL_LINUX_MAP_H__ 2 | #define __NVM_INTERNAL_LINUX_MAP_H__ 3 | #ifdef __linux__ 4 | 5 | #include "linux/ioctl.h" 6 | #include "dma.h" 7 | 8 | 9 | /* 10 | * What kind of memory are we mapping. 
11 | */ 12 | enum mapping_type 13 | { 14 | MAP_TYPE_CUDA = 0x1, // CUDA device memory 15 | MAP_TYPE_HOST = 0x2, // Host memory (RAM) 16 | MAP_TYPE_API = 0x4 // Allocated by the API (RAM) 17 | }; 18 | 19 | 20 | 21 | /* 22 | * Mapping container 23 | */ 24 | struct ioctl_mapping 25 | { 26 | enum mapping_type type; // What kind of memory 27 | void* buffer; 28 | struct va_range range; // Memory range descriptor 29 | }; 30 | 31 | 32 | #endif /* __linux__ */ 33 | #endif /* __NVM_INTERNAL_LINUX_MAP_H__ */ 34 | -------------------------------------------------------------------------------- /IOStack/ssd_write.cu: -------------------------------------------------------------------------------- 1 | #include "iostack_decouple.cuh" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #define TEST_SIZE 0x17D7840000 9 | #define NUM_QUEUES_PER_SSD 128 10 | #define NUM_SSDS 6 11 | 12 | // Macro for checking cuda errors following a cuda launch or api call 13 | #define cudaCheckError() \ 14 | { \ 15 | cudaError_t e = cudaGetLastError(); \ 16 | if (e != cudaSuccess) { \ 17 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 18 | cudaGetErrorString(e)); \ 19 | exit(EXIT_FAILURE); \ 20 | } \ 21 | } 22 | 23 | 24 | __device__ float **IO_buf_base; 25 | 26 | __device__ uint64_t seed; 27 | __global__ void gen_test_data(int ssd_id, int64_t req_id, int block_id) 28 | { 29 | for (int i = 0; i < MAX_IO_SIZE / 4; i++) 30 | { 31 | // seed = seed * 0x5deece66d + 0xb; 32 | IO_buf_base[ssd_id][i] = req_id; 33 | // if(i % (ITEM_SIZE / 4) == 0){ 34 | // IO_buf_base[ssd_id][i] = block_id * 8 + (i / (ITEM_SIZE / 4));//i + block_id * MAX_IO_SIZE / 4;//req_id;// * MAX_IO_SIZE / 8 + i; 35 | // }else{ 36 | // IO_buf_base[ssd_id][i] = i%(ITEM_SIZE / 4); 37 | // } 38 | } 39 | } 40 | 41 | __global__ void gen_feat_data(int ssd_id, int block_id, float* feature){ 42 | for (int64_t i = 0; i < MAX_IO_SIZE / 4; i++) 43 | { 44 | IO_buf_base[ssd_id][i] = feature[i + int64_t((MAX_IO_SIZE / 4))*block_id]; 45 | } 46 | } 47 | 48 | __global__ void check_test_data(float *app_buf, int idx) 49 | { 50 | for (int i = 0; i < MAX_IO_SIZE / 4; i++) 51 | { 52 | seed = seed * 0x5deece66d + 0xb; 53 | if (app_buf[i] != idx * MAX_IO_SIZE / 4 + i) 54 | { 55 | printf("check failed at block %d, i = %d, read %lx, expected %x\n", idx, i, app_buf[i], idx * MAX_IO_SIZE / 8 + i); 56 | assert(0); 57 | } 58 | } 59 | } 60 | 61 | __global__ void fill_app_buf(float *app_buf) 62 | { 63 | for (int i = 0; i < TEST_SIZE / 4; i++) 64 | app_buf[i] = 0; 65 | } 66 | 67 | void mmap_features_read(std::string &features_file, float* features){ 68 | int64_t n_idx = 0; 69 | int32_t fd = open(features_file.c_str(), O_RDONLY); 70 | if(fd == -1){ 71 | std::cout<<"cannout open file: "< lbs; 108 | 109 | int percent = 1; 110 | clock_t clstart = clock(); 111 | cudaCheckError(); 112 | 113 | for (int64_t i = 0; i < num_reqs; i++) 114 | { 115 | uint64_t lb; 116 | lb = i; 117 | 118 | int ssd_id = lb * MAX_ITEMS / NUM_LBS_PER_SSD; 119 | 120 | // gen_feat_data<<<1,1>>>(ssd_id, i, host_float_feature); 121 | gen_test_data<<<1, 1>>>(ssd_id, i, lb); 122 | // cudaCheckError(); 123 | 124 | iostack.write_data(ssd_id, (lb * MAX_ITEMS) % NUM_LBS_PER_SSD, MAX_IO_SIZE / LB_SIZE); 125 | cudaCheckError(); 126 | 127 | if(i % 10000000 == 0){ 128 | printf("req %lu\n", i); 129 | } 130 | if (i >= num_reqs / 1000 * percent) 131 | { 132 | double eta = (clock() - clstart) / (double)CLOCKS_PER_SEC / percent * (1000 - percent); 133 | fprintf(stderr, "generating test data: %d%% done, 
eta %.0lfs\r", percent/10, eta); 134 | percent++; 135 | } 136 | cudaCheckError(); 137 | } 138 | CHECK(cudaDeviceSynchronize()); 139 | std::cout<<"Finish Writing SSD\n"; 140 | } -------------------------------------------------------------------------------- /IOStack/ssdqp.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "common.cuh" 5 | class SSDQueuePair 6 | { 7 | public: 8 | volatile uint32_t *sq; 9 | volatile uint32_t *cq; 10 | uint32_t sq_tail; 11 | uint32_t sq_tail_old; 12 | uint32_t cq_head; 13 | uint32_t cmd_id; // also number of commands submitted 14 | uint32_t namespace_id; 15 | uint32_t *sqtdbl, *cqhdbl; 16 | uint32_t *cmd_id_to_req_id; 17 | uint32_t *cmd_id_to_sq_pos; 18 | bool *sq_entry_busy; 19 | uint32_t queue_depth; 20 | uint32_t num_completed; 21 | 22 | __host__ __device__ SSDQueuePair() 23 | { 24 | } 25 | 26 | __host__ __device__ SSDQueuePair(volatile uint32_t *sq, volatile uint32_t *cq, uint32_t namespace_id, uint32_t *sqtdbl, uint32_t *cqhdbl, uint32_t queue_depth, uint32_t *cmd_id_to_req_id = nullptr, uint32_t *cmd_id_to_sq_pos = nullptr, bool *sq_entry_busy = nullptr) 27 | : sq(sq), cq(cq), sq_tail(0), cq_head(0), cmd_id(0), namespace_id(namespace_id), sqtdbl(sqtdbl), cqhdbl(cqhdbl), cmd_id_to_req_id(cmd_id_to_req_id), cmd_id_to_sq_pos(cmd_id_to_sq_pos), sq_entry_busy(sq_entry_busy), queue_depth(queue_depth), num_completed(0) 28 | { 29 | sq_tail_old = 0; 30 | } 31 | 32 | __host__ __device__ void submit(uint32_t &cid, uint32_t opcode, uint64_t prp1, uint64_t prp2, uint32_t dw10, uint32_t dw11, uint32_t dw12 = 0) 33 | { 34 | // printf("%lx %lx %x %x %x %x %x\n", prp1, prp2, dw10, dw11, dw12, opcode, cmd_id); 35 | fill_sq(cmd_id, sq_tail, opcode, prp1, prp2, dw10, dw11, dw12); 36 | sq_tail = (sq_tail + 1) % queue_depth; 37 | *sqtdbl = sq_tail; 38 | cid = cmd_id; 39 | cmd_id = (cmd_id + 1) & CID_MASK; 40 | } 41 | 42 | __host__ __device__ void fill_sq(uint32_t cid, uint32_t pos, uint32_t opcode, uint64_t prp1, uint64_t prp2, uint32_t dw10, uint32_t dw11, uint32_t dw12 = 0, uint32_t req_id = 0xffffffff) 43 | { 44 | // if (req_id == 1152) 45 | sq[pos * 16 + 0] = opcode | (cid << 16); 46 | sq[pos * 16 + 1] = namespace_id; 47 | sq[pos * 16 + 6] = prp1 & 0xffffffff; 48 | sq[pos * 16 + 7] = prp1 >> 32; 49 | sq[pos * 16 + 8] = prp2 & 0xffffffff; 50 | sq[pos * 16 + 9] = prp2 >> 32; 51 | sq[pos * 16 + 10] = dw10; 52 | sq[pos * 16 + 11] = dw11; 53 | sq[pos * 16 + 12] = dw12; 54 | // printf("%u %u %u %u %u %u %u %u\n", opcode | (cid << 16), namespace_id, prp1 & 0xffffffff, prp1 >> 32, prp2 & 0xffffffff, prp2 >> 32, dw10, dw11, dw12); 55 | // printf("%u, %u\n", namespace_id, opcode); 56 | if (cmd_id_to_req_id) 57 | cmd_id_to_req_id[cid % queue_depth] = req_id; 58 | if (cmd_id_to_sq_pos) 59 | cmd_id_to_sq_pos[cid % queue_depth] = pos; 60 | if (sq_entry_busy) 61 | sq_entry_busy[pos] = true; 62 | } 63 | 64 | __host__ __device__ void poll(uint32_t &code, uint32_t cid) 65 | { 66 | uint32_t current_phase = ((cmd_id - 1) / queue_depth) & 1; 67 | uint32_t status = cq[cq_head * 4 + 3]; 68 | while (((status & PHASE_MASK) >> 16) == current_phase) 69 | status = cq[cq_head * 4 + 3]; 70 | if ((status & CID_MASK) != cid) 71 | { 72 | printf("expected cid: %d, actual cid: %d\n", cid, status & CID_MASK); 73 | assert(0); 74 | } 75 | cq_head = (cq_head + 1) % queue_depth; 76 | *cqhdbl = cq_head; 77 | code = (status >> 17) & SC_MASK; 78 | num_completed++; 79 | } 80 | 81 | __device__ uint32_t poll_range(int 
expected_sq_head, bool should_break) 82 | { 83 | // printf("cmd_id: %d, size: %d, current_phase: %d\n", cmd_id, size, current_phase); 84 | int i; 85 | uint32_t last_sq_head = ~0U; 86 | int last_num_completed = num_completed; 87 | int thread_id = threadIdx.x + blockIdx.x * blockDim.x; 88 | for (i = cq_head; (num_completed & CID_MASK) != cmd_id; i = (i + 1) % queue_depth) 89 | { 90 | uint32_t current_phase = (num_completed / queue_depth) & 1; 91 | uint32_t status = cq[i * 4 + 3]; 92 | uint64_t start = clock64(); 93 | while (((status & PHASE_MASK) >> 16) == current_phase) 94 | { 95 | status = cq[i * 4 + 3]; 96 | if (clock64() - start > 1000000000) 97 | { 98 | printf("timeout sq_tail=%d, cq_head=%d, i=%d, num_completed=%d, cmd_id=%d\n", sq_tail, cq_head, i, num_completed, cmd_id); 99 | printf("last_sq_head: %d, expected_sq_head: %d\n", last_sq_head, expected_sq_head); 100 | // int thread_id = blockIdx.x * blockDim.x + threadIdx.x; 101 | // if (thread_id) 102 | // return 0; 103 | // for (int m = 0; m < queue_depth; m++) 104 | // { 105 | // printf("SQE %d\n", m); 106 | // for (int n = 0; n < 16; n++) 107 | // printf("DW%2d, %08x\n", n, sq[m * 16 + n]); 108 | // } 109 | // for (int m = 0; m < queue_depth; m++) 110 | // { 111 | // printf("CQE %d\n", m); 112 | // for (int n = 0; n < 4; n++) 113 | // printf("DW%2d, %08x\n", n, cq[m * 4 + n]); 114 | // } 115 | return 1; 116 | } 117 | } 118 | int cmd_id = status & CID_MASK; 119 | int sq_pos = cmd_id_to_sq_pos[cmd_id % queue_depth]; 120 | if ((status >> 17) & SC_MASK) 121 | { 122 | printf("cq[%d] status: 0x%x, cid: %d\n", i, (status >> 17) & SC_MASK, status & CID_MASK); 123 | int req_id = cmd_id_to_req_id[cmd_id % queue_depth]; 124 | printf("req_id: %d, sq_pos: %d\n", req_id, sq_pos); 125 | // for (int i = 0; i < 16; i++) 126 | // printf("%08x ", sq[sq_pos * 16 + i]); 127 | // printf("\n"); 128 | return (status >> 17) & SC_MASK; 129 | } 130 | last_sq_head = cq[i * 4 + 2] & SQ_HEAD_MASK; 131 | sq_entry_busy[sq_pos] = false; 132 | // printf("thread %d freed sq_pos %d\n", thread_id, sq_pos); 133 | num_completed++; 134 | if (should_break && ((cq[i * 4 + 2] & SQ_HEAD_MASK) - expected_sq_head + queue_depth) % queue_depth <= WARP_SIZE) 135 | { 136 | // printf("cq[%d] sq_head: %d, expected_sq_head: %d\n", i, cq[i * 4 + 2] & SQ_HEAD_MASK, expected_sq_head); 137 | i = (i + 1) % queue_depth; 138 | if (num_completed - last_num_completed > 64) 139 | printf("%d: %d completed\n", thread_id, num_completed - last_num_completed); 140 | break; 141 | } 142 | } 143 | if (i != cq_head) 144 | { 145 | cq_head = i; 146 | // printf("cq_head is %p, set cqhdbl to %d\n", cqhdbl, cq_head); 147 | *cqhdbl = cq_head; 148 | } 149 | return 0; 150 | } 151 | 152 | __host__ __device__ uint32_t poll_until_sq_entry_free(int expected_sq_pos) 153 | { 154 | int thread_id = blockIdx.x * blockDim.x + threadIdx.x; 155 | int last_num_completed = num_completed; 156 | // printf("thread %d want to free sq_pos: %d num_completed %d cmd_id %d\n", thread_id, expected_sq_pos, num_completed, cmd_id); 157 | int i; 158 | for (i = cq_head; (num_completed & CID_MASK) != cmd_id; i = (i + 1) % queue_depth) 159 | { 160 | uint32_t current_phase = (num_completed / queue_depth) & 1; 161 | uint32_t status = cq[i * 4 + 3]; 162 | while (((status & PHASE_MASK) >> 16) == current_phase) 163 | status = cq[i * 4 + 3]; 164 | int cmd_id = status & CID_MASK; 165 | int sq_pos = cmd_id_to_sq_pos[cmd_id % queue_depth]; 166 | if ((status >> 17) & SC_MASK) 167 | { 168 | printf("cq[%d] status: 0x%x, cid: %d\n", i, (status >> 17) & 
SC_MASK, status & CID_MASK); 169 | int req_id = cmd_id_to_req_id[cmd_id % queue_depth]; 170 | printf("req_id: %d, sq_pos: %d\n", req_id, sq_pos); 171 | // for (int i = 0; i < 16; i++) 172 | // printf("%08x ", sq[sq_pos * 16 + i]); 173 | // printf("\n"); 174 | return (status >> 17) & SC_MASK; 175 | } 176 | sq_entry_busy[sq_pos] = false; 177 | // printf("thread %d manually freed sq_pos %d\n", thread_id, sq_pos); 178 | num_completed++; 179 | if (sq_pos == expected_sq_pos) 180 | { 181 | cq_head = (i + 1) % queue_depth; 182 | // printf("cq_head is %p, set cqhdbl to %d\n", cqhdbl, cq_head); 183 | *cqhdbl = cq_head; 184 | if (num_completed - last_num_completed > 64) 185 | printf("%d: %d completed\n", thread_id, num_completed - last_num_completed); 186 | return 0; 187 | } 188 | } 189 | // printf("thread %d failed to free sq_pos %d\n", thread_id, expected_sq_pos); 190 | return 1; 191 | } 192 | }; 193 | -------------------------------------------------------------------------------- /IOStack/test.cu: -------------------------------------------------------------------------------- 1 | #include "iostack.cuh" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | // Configuration parameters (replace macros with variables) 10 | constexpr size_t TEST_SIZE = 0x10000000; // Test data size, in bytes 11 | constexpr size_t APP_BUF_SIZE = 0x10000000; // Application buffer size 12 | constexpr int NUM_QUEUES_PER_SSD = 128; // Number of queues per SSD 13 | constexpr int NUM_SSDS = 2; // Number of SSDs 14 | constexpr int IO_SIZE = 4096; // Total IO size of a single request, in bytes 15 | 16 | // Kernel to initialize application buffer 17 | __global__ void fill_app_buf(uint64_t *app_buf) { 18 | for (int i = 0; i < TEST_SIZE / 8; i++) { 19 | app_buf[i] = 0; 20 | } 21 | } 22 | 23 | // Main function 24 | int main() { 25 | // Initialize IOStack 26 | IOStack iostack(NUM_SSDS, NUM_QUEUES_PER_SSD, 1, 32); 27 | 28 | // Allocate application buffer 29 | uint64_t *app_buf; 30 | CHECK(cudaMalloc(&app_buf, APP_BUF_SIZE)); 31 | fill_app_buf<<<1, 1>>>(app_buf); 32 | 33 | // Allocate and initialize request buffers 34 | int num_reqs = TEST_SIZE / IO_SIZE; 35 | IOReq *reqs; 36 | CHECK(cudaMalloc(&reqs, sizeof(IOReq) * num_reqs)); 37 | IOReq *h_reqs; 38 | CHECK(cudaHostAlloc(&h_reqs, sizeof(IOReq) * num_reqs, cudaHostAllocMapped)); 39 | 40 | std::unordered_set lbs; 41 | srand(time(NULL)); 42 | 43 | int percent = 1; 44 | clock_t clstart = clock(); 45 | 46 | // Generate test requests 47 | for (int i = 0; i < num_reqs; i++) { 48 | uint64_t lb; 49 | do { 50 | uint64_t idx = (((unsigned long)rand() << 31) | rand()); 51 | lb = (idx % NUM_SSDS) * (NUM_LBS_PER_SSD / MAX_ITEMS) + idx % (NUM_LBS_PER_SSD / MAX_ITEMS); 52 | } while (lbs.find(lb) != lbs.end()); 53 | lbs.insert(lb); 54 | 55 | h_reqs[i].start_lb = lb * MAX_ITEMS; 56 | h_reqs[i].num_items = MAX_ITEMS; 57 | for (int j = 0; j < MAX_ITEMS; j++) { 58 | h_reqs[i].dest_addr[j] = (uint64_t)(app_buf + (1ll * i * IO_SIZE + j * LBS) % APP_BUF_SIZE / sizeof(uint64_t)); 59 | } 60 | 61 | if (i >= num_reqs / 100 * percent) { 62 | double eta = (clock() - clstart) / (double)CLOCKS_PER_SEC / percent * (100 - percent); 63 | fprintf(stderr, "generating test data: %d%% done, eta %.0lfs\r", percent, eta); 64 | percent++; 65 | } 66 | } 67 | CHECK(cudaDeviceSynchronize()); 68 | 69 | // Copy requests to device memory 70 | CHECK(cudaMemcpy(reqs, h_reqs, sizeof(IOReq) * num_reqs, cudaMemcpyHostToDevice)); 71 | 72 | // Run IO requests multiple times and measure performance 73 | 
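// The loop below replays the full request set `repeat` times; each pass issues three fixed 10240-request io_submission slices plus the remainder, and a single io_completion(0) then waits for everything outstanding before the next pass is timed.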
int repeat = 10; 74 | cudaEvent_t start, stop; 75 | CHECK(cudaEventCreate(&start)); 76 | CHECK(cudaEventCreate(&stop)); 77 | CHECK(cudaEventRecord(start)); 78 | 79 | fprintf(stderr, "starting do_io_req...\n"); 80 | for (int i = 0; i < repeat; i++) { 81 | iostack.io_submission(reqs, 10240, 0); 82 | iostack.io_submission(reqs + 10240, 10240, 0); 83 | iostack.io_submission(reqs + 20480, 10240, 0); 84 | iostack.io_submission(reqs + 30720, num_reqs - 30720, 0); 85 | iostack.io_completion(0); 86 | } 87 | 88 | CHECK(cudaEventRecord(stop)); 89 | CHECK(cudaEventSynchronize(stop)); 90 | float ms; 91 | CHECK(cudaEventElapsedTime(&ms, start, stop)); 92 | 93 | // Output performance results 94 | fprintf(stderr, "do_io_req takes %f ms\n", ms); 95 | fprintf(stderr, "%dB random read bandwidth: %f MiB/s\n", IO_SIZE, TEST_SIZE * repeat / (1024.0 * 1024.0) / (ms / 1000)); 96 | return 0; 97 | } 98 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | cd sampling_server && \ 2 | make clean && make -j 8 && \ 3 | cd .. 4 | cd training_backend && \ 5 | python setup.py install && \ 6 | cd .. 7 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # This file elaborates more on dataset preparation 2 | 3 | ## Legion Format 4 | Take uk-union as an example 5 | ``` 6 | Edge: uk-union/edge_src, uk-union/edge_dst ## These are topology files in CSR format. edge_src is int64, edge_dst is int32. 7 | feature: uk-union/features ## size = number of vertices x feature size 8 | label: uk-union/labels ## size = number of vertices 9 | partition: uk-union/partition ## partition_id for each vertex 10 | ``` 11 | ## Customize your datasets 12 | ``` 13 | cd dataset/ 14 | ``` 15 | Create the environment for WebGraph 16 | ``` 17 | mkdir lib 18 | mv webgraph-3.5.2.jar lib/ 19 | tar -xzvf webgraph-3.6.8-deps.tar.gz -C lib 20 | ``` 21 | Take uk-union as an example 22 | ``` 23 | mkdir ukunion 24 | cd ukunion 25 | wget http://data.law.di.unimi.it/webdata/uk-union-2006-06-2007-05/uk-union-2006-06-2007-05-underlying.graph 26 | wget http://data.law.di.unimi.it/webdata/uk-union-2006-06-2007-05/uk-union-2006-06-2007-05-underlying.properties 27 | cd .. 28 | java -cp "lib/*" it.unimi.dsi.webgraph.ArcListASCIIGraph ukunion/uk-union-2006-06-2007-05-underlying ukunion/ukunion-edgelist.txt 29 | 30 | mkdir xtrapulp_result 31 | # generate legion-format edge_src edge_dst, and the input of xtrapulp 32 | g++ gen_legion_xtrapulp_fomat.cpp -o gen_legion_xtrapulp_fomat 33 | ./gen_legion_xtrapulp_fomat ukunion ukunion-edgelist.txt 34 | # generate training sets, validation sets, and test sets 35 | python gen_sets.py --dataset_name ukunion 36 | 37 | ``` 38 | 39 | # 2. 
Graph partitioning 40 | ## Install MPI 41 | ``` 42 | wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.0.tar.gz 43 | tar zxf openmpi-3.1.0.tar.gz 44 | cd openmpi-3.1.0 45 | sudo ./configure --prefix=/usr/local/openmpi 46 | sudo make 47 | sudo make install 48 | MPI_HOME=/usr/local/openmpi 49 | export PATH=${MPI_HOME}/bin:$PATH 50 | export LD_LIBRARY_PATH=${MPI_HOME}/lib:$LD_LIBRARY_PATH 51 | export MANPATH=${MPI_HOME}/share/man:$MANPATH 52 | 53 | # or use the following instructions instead 54 | # sudo apt-get install openmpi-bin openmpi-doc libopenmpi-dev 55 | # sudo apt-get install mpich libmpich-dev 56 | 57 | ``` 58 | ## Install xtrapulp (refer to https://github.com/luoxiaojian/xtrapulp) 59 | ``` 60 | git clone https://github.com/luoxiaojian/xtrapulp.git 61 | mv ukunion_xtraformat xtrapulp/ 62 | cd xtrapulp 63 | make 64 | make libxtrapulp 65 | cd ../../ 66 | ``` -------------------------------------------------------------------------------- /dataset/convert_to_bin.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | # Assume the input file is named 'data.txt' 4 | file_path = 'data.txt' 5 | 6 | # Read the file with pandas read_csv 7 | df = pd.read_csv(file_path, header=None, delimiter="\s+") 8 | 9 | # Convert to a numpy array 10 | data = df.to_numpy() 11 | 12 | print(data) 13 | -------------------------------------------------------------------------------- /dataset/gen_legion_xtrapulp_fomat.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | using namespace std; 13 | 14 | int64_t node_nums = 133633040; 15 | int64_t edge_nums = 5507679822; 16 | 17 | #define ten_e 1000000000 18 | #define one_e 100000000 19 | 20 | vector> vc; 21 | vector indptr; 22 | vector indices; 23 | vector edges; 24 | 25 | int64_t lines = 0; 26 | int cnt = 1; 27 | int reid = 0; 28 | int64_t qq = one_e; 29 | int64_t gl_nums = 0; 30 | bool flag[300000000]; 31 | int32_t* idmap; 32 | 33 | void read(std::string &file_path) { 34 | int fd = open(file_path.c_str(), O_RDONLY); 35 | int64_t buf_size = lseek(fd, 0, SEEK_END); 36 | char* buf = (char*)mmap(NULL, buf_size, PROT_READ, MAP_PRIVATE, fd, 0); 37 | const char* buf_end = buf + buf_size; 38 | int src, dst; 39 | std::string str = ""; 40 | while(buf < buf_end) { 41 | if(*buf == '\t'){ 42 | src = stoi(str); 43 | str = ""; 44 | buf++; 45 | continue; 46 | } 47 | if (*buf == '\n') { 48 | lines++; 49 | if(lines%qq==0){ 50 | cout<<"loaded "< " << std::endl; 113 | return 1; 114 | } 115 | std::string file_dir = argv[1]; 116 | std::string file_name = argv[2]; 117 | std::string file_path = file_dir + "/" + file_name; 118 | // string file_path = "ukunion/ukunion-edgelist.txt"; 119 | 120 | memset(flag,false,sizeof(flag)); 121 | idmap = (int32_t*)malloc(int64_t(int64_t(300000000) * sizeof(int32_t))); 122 | 123 | cout<<"edges loading:"< -------------------------------------------------------------------------------- /sampling_server/hyperion.cpp: -------------------------------------------------------------------------------- 1 | #include <pybind11/pybind11.h> 2 | #include <pybind11/stl.h> 3 | #include "server.h" // defines the Server class and the factory function 4 | 5 | namespace py = pybind11; 6 | 7 | PYBIND11_MODULE(hyperion, m) { 8 | m.doc() = "Hyperion Python bindings"; 9 | 10 | // Bind the Server interface 11 | py::class_<Server>(m, "Server") 12 | .def("initialize", &Server::Initialize, "Initialize the server", 13 | py::arg("gpu_number") = 1, py::arg("fanout") = std::vector<int>{}) 14 | .def("presc", &Server::PreSc, "Set cache aggregate mode", py::arg("mode") = 0) 15 | .def("run", &Server::Run, "Run the server") 16 | .def("finalize", &Server::Finalize, "Finalize the
server"); 17 | 18 | // 工厂函数绑定 19 | m.def("NewGPUServer", &NewGPUServer, "Create a new GPU Server instance"); 20 | } 21 | -------------------------------------------------------------------------------- /sampling_server/src/cache/cache.cuh: -------------------------------------------------------------------------------- 1 | #ifndef CACHE_H 2 | #define CACHE_H 3 | 4 | #include "graph_storage.cuh" 5 | #include "feature_storage.cuh" 6 | 7 | #include 8 | #include 9 | 10 | class CacheController{ 11 | public: 12 | virtual ~CacheController() = default; 13 | 14 | virtual void Initialize( 15 | int32_t dev_id, 16 | int32_t total_num_nodes) = 0; 17 | 18 | virtual void Finalize() = 0; 19 | 20 | virtual void FindFeat( 21 | int32_t* sampled_ids, 22 | int32_t* cache_offset, 23 | int32_t* node_counter, 24 | int32_t op_id, 25 | void* stream) = 0; 26 | 27 | virtual void FindTopo(int32_t* input_ids, 28 | int32_t* partition_index, 29 | int32_t* partition_offset, 30 | int32_t batch_size, 31 | int32_t op_id, 32 | void* strm_hdl, 33 | int32_t device_id) = 0; 34 | 35 | // virtual void FindTopoSSD(int32_t* sampled_ids, 36 | // int32_t* cache_offset, 37 | // int32_t* node_counter, 38 | // int32_t op_id, 39 | // void* stream) = 0; 40 | 41 | virtual void CacheProfiling( 42 | int32_t* sampled_ids, 43 | int32_t* agg_src_id, 44 | int32_t* agg_dst_id, 45 | int32_t* agg_src_off, 46 | int32_t* agg_dst_off, 47 | int32_t* node_counter, 48 | int32_t* edge_counter, 49 | bool is_presc, 50 | void* stream) = 0; 51 | 52 | virtual void InitializeMap(int node_capacity, int edge_capacity) = 0; 53 | 54 | virtual void UnifiedInsert(int32_t* QF, int32_t* QT, int32_t gpu_feat_num, int32_t cpu_feat_num, int32_t gpu_topo_num, int32_t cpu_topo_num) = 0; 55 | 56 | virtual void AccessCount( 57 | int32_t* d_key, 58 | int32_t num_keys, 59 | void* stream) = 0; 60 | 61 | virtual unsigned long long int* GetNodeAccessedMap() = 0; 62 | 63 | virtual unsigned long long int* GetEdgeAccessedMap() = 0; 64 | 65 | virtual int32_t MaxIdNum() = 0; 66 | }; 67 | 68 | CacheController* NewPreSCCacheController(int32_t train_step, int32_t device_count); 69 | 70 | class UnifiedCache{ 71 | public: 72 | void Initialize( 73 | int32_t float_feature_len, 74 | int32_t train_step, 75 | int32_t device_count, 76 | int64_t cpu_topo_size, 77 | int64_t gpu_topo_size, 78 | int64_t cpu_feat_size, 79 | int64_t gpu_feat_size 80 | ); 81 | 82 | void InitializeCacheController( 83 | int32_t dev_id, 84 | int32_t total_num_nodes); 85 | 86 | void Finalize(int32_t dev_id); 87 | 88 | void FindFeat( 89 | int32_t* sampled_ids, 90 | int32_t* cache_offset, 91 | int32_t* node_counter, 92 | int32_t op_id, 93 | void* stream, 94 | int32_t dev_id); 95 | 96 | void FindTopo( 97 | int32_t* input_ids, 98 | int32_t* partition_index, 99 | int32_t* partition_offset, 100 | int32_t batch_size, 101 | int32_t op_id, 102 | void* strm_hdl, 103 | int32_t dev_id); 104 | 105 | void CacheProfiling( 106 | int32_t* sampled_ids, 107 | int32_t* agg_src_id, 108 | int32_t* agg_dst_id, 109 | int32_t* agg_src_off, 110 | int32_t* agg_dst_off, 111 | int32_t* node_counter, 112 | int32_t* edge_counter, 113 | void* stream, 114 | int32_t dev_id); 115 | 116 | void AccessCount( 117 | int32_t* d_key, 118 | int32_t num_keys, 119 | void* stream, 120 | int32_t dev_id); 121 | 122 | void HybridInit(FeatureStorage* feature, GraphStorage* graph); 123 | 124 | int32_t MaxIdNum(int32_t dev_id); 125 | 126 | unsigned long long int* GetEdgeAccessedMap(int32_t dev_id); 127 | 128 | void FeatCacheLookup(int32_t* sampled_ids, int32_t* cache_index, 
129 | int32_t* node_counter, float* dst_float_buffer, 130 | int32_t op_id, int32_t dev_id, cudaStream_t strm_hdl); 131 | 132 | private: 133 | int32_t NodeCapacity(int32_t dev_id); 134 | 135 | int32_t CPUCapacity(); 136 | 137 | int32_t GPUCapacity(); 138 | 139 | float* Float_Feature_Cache(int32_t dev_id);//return all features 140 | 141 | float** Global_Float_Feature_Cache(int32_t dev_id); 142 | 143 | std::vector dev_ids_;/*valid device, indexed by device id, False means invalid, True means valid*/ 144 | 145 | int32_t device_count_; 146 | 147 | std::vector cache_controller_; 148 | 149 | std::vector QF_; 150 | std::vector QT_; 151 | std::vector GF_; 152 | std::vector GT_; 153 | std::vector AF_; 154 | std::vector AT_; 155 | int Kc_; 156 | int Kg_; 157 | 158 | std::vector node_capacity_; 159 | std::vector edge_capacity_; 160 | 161 | int32_t cpu_cache_capacity_;//for legion ssd 162 | int32_t gpu_cache_capacity_;//for legion ssd 163 | 164 | int64_t cpu_topo_size_; 165 | int64_t gpu_topo_size_; 166 | int64_t cpu_feat_size_; 167 | int64_t gpu_feat_size_; 168 | 169 | int64_t cpu_topo_num_; 170 | int64_t gpu_topo_num_; 171 | int64_t cpu_feat_num_; 172 | int64_t gpu_feat_num_; 173 | 174 | int64_t cache_memory_; 175 | std::vector sidx_; 176 | 177 | std::vector int_feature_cache_; 178 | std::vector float_feature_cache_; 179 | std::vector d_float_feature_cache_ptr_; 180 | 181 | int32_t float_feature_len_; 182 | int32_t total_num_nodes_; 183 | float* cpu_float_features_; 184 | 185 | bool is_presc_; 186 | }; 187 | 188 | 189 | 190 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/helper_multiprocess.cu: -------------------------------------------------------------------------------- 1 | #include "helper_multiprocess.h" 2 | #include 3 | #include 4 | 5 | int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info) 6 | { 7 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 8 | info->size = sz; 9 | info->shmHandle = CreateFileMapping(INVALID_HANDLE_VALUE, 10 | NULL, 11 | PAGE_READWRITE, 12 | 0, 13 | (DWORD)sz, 14 | name); 15 | if (info->shmHandle == 0) { 16 | return GetLastError(); 17 | } 18 | 19 | info->addr = MapViewOfFile(info->shmHandle, FILE_MAP_ALL_ACCESS, 0, 0, sz); 20 | if (info->addr == NULL) { 21 | return GetLastError(); 22 | } 23 | 24 | return 0; 25 | #else 26 | int status = 0; 27 | 28 | info->size = sz; 29 | 30 | info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777); 31 | if (info->shmFd < 0) { 32 | return errno; 33 | } 34 | 35 | status = ftruncate(info->shmFd, sz); 36 | if (status != 0) { 37 | return status; 38 | } 39 | 40 | info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); 41 | if (info->addr == NULL) { 42 | return errno; 43 | } 44 | 45 | return 0; 46 | #endif 47 | } 48 | 49 | int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info) 50 | { 51 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 52 | info->size = sz; 53 | 54 | info->shmHandle = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name); 55 | if (info->shmHandle == 0) { 56 | return GetLastError(); 57 | } 58 | 59 | info->addr = MapViewOfFile(info->shmHandle, FILE_MAP_ALL_ACCESS, 0, 0, sz); 60 | if (info->addr == NULL) { 61 | return GetLastError(); 62 | } 63 | 64 | return 0; 65 | #else 66 | info->size = sz; 67 | 68 | info->shmFd = shm_open(name, O_RDWR, 0777); 69 | if (info->shmFd < 0) { 70 | return errno; 71 | } 72 | 73 | info->addr = mmap(0, sz, PROT_READ | 
PROT_WRITE, MAP_SHARED, info->shmFd, 0); 74 | if (info->addr == NULL) { 75 | return errno; 76 | } 77 | 78 | return 0; 79 | #endif 80 | } 81 | 82 | void sharedMemoryClose(sharedMemoryInfo *info) 83 | { 84 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 85 | if (info->addr) { 86 | UnmapViewOfFile(info->addr); 87 | } 88 | if (info->shmHandle) { 89 | CloseHandle(info->shmHandle); 90 | } 91 | #else 92 | if (info->addr) { 93 | munmap(info->addr, info->size); 94 | } 95 | if (info->shmFd) { 96 | close(info->shmFd); 97 | } 98 | #endif 99 | } 100 | 101 | int spawnProcess(Process *process, const char *app, char * const *args) 102 | { 103 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 104 | STARTUPINFO si = {0}; 105 | BOOL status; 106 | size_t arglen = 0; 107 | size_t argIdx = 0; 108 | std::string arg_string; 109 | memset(process, 0, sizeof(*process)); 110 | 111 | while (*args) { 112 | arg_string.append(*args).append(1, ' '); 113 | args++; 114 | } 115 | 116 | status = CreateProcess(app, LPSTR(arg_string.c_str()), NULL, NULL, FALSE, 0, NULL, NULL, &si, process); 117 | 118 | return status ? 0 : GetLastError(); 119 | #else 120 | *process = fork(); 121 | if (*process == 0) { 122 | if (0 > execvp(app, args)) { 123 | return errno; 124 | } 125 | } 126 | else if (*process < 0) { 127 | return errno; 128 | } 129 | return 0; 130 | #endif 131 | } 132 | 133 | int waitProcess(Process *process) 134 | { 135 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 136 | DWORD exitCode; 137 | WaitForSingleObject(process->hProcess, INFINITE); 138 | GetExitCodeProcess(process->hProcess, &exitCode); 139 | CloseHandle(process->hProcess); 140 | CloseHandle(process->hThread); 141 | return (int)exitCode; 142 | #else 143 | int status = 0; 144 | do { 145 | if (0 > waitpid(*process, &status, 0)) { 146 | return errno; 147 | } 148 | } while (!WIFEXITED(status)); 149 | return WEXITSTATUS(status); 150 | #endif 151 | } 152 | -------------------------------------------------------------------------------- /sampling_server/src/engine/helper_multiprocess.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-2018 NVIDIA Corporation. All rights reserved. 3 | * 4 | * Please refer to the NVIDIA end user license agreement (EULA) associated 5 | * with this source code for terms and conditions that govern your use of 6 | * this software. Any use, reproduction, disclosure, or distribution of 7 | * this software and related documentation outside the terms of the EULA 8 | * is strictly prohibited. 
9 | * 10 | */ 11 | 12 | #ifndef HELPER_MULTIPROCESS_H 13 | #define HELPER_MULTIPROCESS_H 14 | 15 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 16 | #ifndef WIN32_LEAN_AND_MEAN 17 | #define WIN32_LEAN_AND_MEAN 18 | #endif 19 | #include 20 | #else 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #endif 27 | 28 | typedef struct sharedMemoryInfo_st { 29 | void *addr; 30 | size_t size; 31 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 32 | HANDLE shmHandle; 33 | #else 34 | int shmFd; 35 | #endif 36 | } sharedMemoryInfo; 37 | 38 | int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info); 39 | 40 | int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info); 41 | 42 | void sharedMemoryClose(sharedMemoryInfo *info); 43 | 44 | 45 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 46 | typedef PROCESS_INFORMATION Process; 47 | #else 48 | typedef pid_t Process; 49 | #endif 50 | 51 | int spawnProcess(Process *process, const char *app, char * const *args); 52 | 53 | int waitProcess(Process *process); 54 | 55 | #endif // HELPER_MULTIPROCESS_H 56 | -------------------------------------------------------------------------------- /sampling_server/src/engine/ipc_service.h: -------------------------------------------------------------------------------- 1 | #ifndef IPC_ENV_H 2 | #define IPC_ENV_H 3 | #include 4 | #include "buildinfo.h" 5 | 6 | class IPCEnv { 7 | public: 8 | virtual void Coordinate(BuildInfo* info) = 0; 9 | virtual int32_t GetMaxStep() = 0; 10 | 11 | virtual void InitializeSamplesBuffer(int32_t batch_size, int32_t num_ids, int32_t feature_dim, int32_t device_id, int32_t pipeline_depth) = 0; 12 | virtual void InitializeFeaturesBuffer(int32_t batch_size, int32_t num_ids, int32_t feature_dim, int32_t device_id, int32_t pipeline_depth) = 0; 13 | 14 | virtual int32_t GetRawBatchsize() = 0; 15 | virtual int32_t GetLocalBatchId(int32_t global_batch_id) = 0; 16 | virtual int32_t GetCurrentBatchsize(int32_t dev_id, int32_t current_mode) = 0; 17 | virtual int32_t GetCurrentMode(int32_t global_batch_id) = 0; 18 | 19 | virtual int32_t* GetIds(int32_t dev_id, int32_t current_pipe) = 0; 20 | virtual float* GetFloatFeatures(int32_t dev_id, int32_t current_pipe) = 0; 21 | virtual int32_t* GetLabels(int32_t dev_id, int32_t current_pipe) = 0; 22 | virtual int32_t* GetAggSrc(int32_t dev_id, int32_t current_pipe) = 0; 23 | virtual int32_t* GetAggDst(int32_t dev_id, int32_t current_pipe) = 0; 24 | virtual int32_t* GetNodeCounter(int32_t dev_id, int32_t current_pipe) = 0; 25 | virtual int32_t* GetEdgeCounter(int32_t dev_id, int32_t current_pipe) = 0; 26 | 27 | virtual void IPCPost(int32_t dev_id, int32_t current_pipe) = 0; 28 | virtual void IPCWait(int32_t dev_id, int32_t current_pipe) = 0; 29 | 30 | virtual void Finalize() = 0; 31 | virtual int32_t GetTrainStep() = 0; 32 | 33 | }; 34 | 35 | IPCEnv* NewIPCEnv(int32_t device_count); 36 | 37 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/memorypool.cu: -------------------------------------------------------------------------------- 1 | #include "memorypool.cuh" 2 | 3 | MemoryPool::MemoryPool(int32_t pipeline_depth){ 4 | pipeline_depth_ = pipeline_depth; 5 | current_pipe_ = 0; 6 | sampled_ids_.resize(pipeline_depth_); 7 | labels_.resize(pipeline_depth_); 8 | float_features_.resize(pipeline_depth_); 9 | agg_dst_off_.resize(pipeline_depth_); 10 | 
agg_src_off_.resize(pipeline_depth_); 11 | node_counter_.resize(pipeline_depth_); 12 | edge_counter_.resize(pipeline_depth_); 13 | } 14 | -------------------------------------------------------------------------------- /sampling_server/src/engine/memorypool.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef GPU_MEMORY_POOL 3 | #define GPU_MEMORY_POOL 4 | #include 5 | #include 6 | #include 7 | 8 | // Macro for checking cuda errors following a cuda launch or api call 9 | #define cudaCheckError() \ 10 | { \ 11 | cudaError_t e = cudaGetLastError(); \ 12 | if (e != cudaSuccess) { \ 13 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 14 | cudaGetErrorString(e)); \ 15 | exit(EXIT_FAILURE); \ 16 | } \ 17 | } 18 | 19 | 20 | class MemoryPool { 21 | public: 22 | MemoryPool(int32_t pipeline_depth); 23 | 24 | int32_t GetOpId(){//only used by sampler 25 | return op_id_; 26 | } 27 | 28 | int32_t GetIter(){ 29 | return iter_; 30 | } 31 | 32 | int32_t GetCurrentMode(){ 33 | return mode_; 34 | } 35 | 36 | float* GetFloatFeatures(){ 37 | return float_features_[current_pipe_]; 38 | } 39 | 40 | int32_t* GetCacheSearchBuffer(){ 41 | return cache_search_buffer_; 42 | } 43 | 44 | int32_t* GetLabels(){ 45 | return labels_[current_pipe_]; 46 | } 47 | 48 | uint32_t* GetAccessedMap(){ 49 | return accessed_map_; 50 | } 51 | 52 | int32_t* GetPositionMap(){ 53 | return position_map_; 54 | } 55 | 56 | int32_t* GetNodeCounter(){ 57 | return node_counter_[current_pipe_]; 58 | } 59 | 60 | int32_t* GetEdgeCounter(){ 61 | return edge_counter_[current_pipe_]; 62 | } 63 | 64 | int32_t* GetSampledIds(){ 65 | return sampled_ids_[current_pipe_]; 66 | } 67 | 68 | int32_t* GetAggSrcId(){ 69 | return agg_src_ids_; 70 | } 71 | 72 | int32_t* GetAggDstId(){ 73 | return agg_dst_ids_; 74 | } 75 | 76 | int32_t* GetAggSrcOf(){ 77 | return agg_src_off_[current_pipe_]; 78 | } 79 | 80 | int32_t* GetAggDstOf(){ 81 | return agg_dst_off_[current_pipe_]; 82 | } 83 | 84 | int32_t* GetTmpSrcOf(){ 85 | return tmp_src_of_; 86 | } 87 | 88 | int32_t* GetTmpDstOf(){ 89 | return tmp_dst_of_; 90 | } 91 | 92 | void* GetTempStorage(){ 93 | return temp_storage_; 94 | } 95 | 96 | int32_t* GetTmpPartIdx(){ 97 | return tmp_part_ind_; 98 | } 99 | 100 | int32_t* GetTmpPartOff(){ 101 | return tmp_part_off_; 102 | } 103 | 104 | 105 | void SetFloatFeatures(float* float_features, int32_t current_pipe){ 106 | float_features_[current_pipe] = float_features; 107 | } 108 | 109 | void SetCacheSearchBuffer(int32_t* cache_search_buffer){ 110 | cache_search_buffer_ = cache_search_buffer; 111 | } 112 | 113 | void SetLabels(int32_t* labels, int32_t current_pipe){ 114 | labels_[current_pipe] = labels; 115 | } 116 | 117 | void SetAccessedMap(uint32_t* accessed_map){ 118 | accessed_map_ = accessed_map; 119 | } 120 | 121 | void SetPositionMap(int32_t* position_map){ 122 | position_map_ = position_map; 123 | } 124 | 125 | void SetNodeCounter(int32_t* node_counter, int32_t current_pipe){ 126 | node_counter_[current_pipe] = node_counter; 127 | } 128 | 129 | void SetEdgeCounter(int32_t* edge_counter, int32_t current_pipe){ 130 | edge_counter_[current_pipe] = edge_counter; 131 | } 132 | 133 | void SetSampledIds(int32_t* sampled_ids, int32_t current_pipe){ 134 | sampled_ids_[current_pipe] = sampled_ids; 135 | } 136 | 137 | void SetAggSrcId(int32_t* agg_src_ids){ 138 | agg_src_ids_ = agg_src_ids; 139 | } 140 | 141 | void SetAggDstId(int32_t* agg_dst_ids){ 142 | agg_dst_ids_ = agg_dst_ids; 143 | } 144 | 145 | void 
SetAggSrcOf(int32_t* agg_src_off, int32_t current_pipe){ 146 | agg_src_off_[current_pipe] = agg_src_off; 147 | } 148 | 149 | void SetAggDstOf(int32_t* agg_dst_off, int32_t current_pipe){ 150 | agg_dst_off_[current_pipe] = agg_dst_off; 151 | } 152 | 153 | void SetTmpSrcOf(int32_t* tmp_src_of){ 154 | tmp_src_of_ = tmp_src_of; 155 | } 156 | 157 | void SetTmpDstOf(int32_t* tmp_dst_of){ 158 | tmp_dst_of_ = tmp_dst_of; 159 | } 160 | 161 | void SetTempStorage(void* temp_storage){ 162 | temp_storage_ = temp_storage; 163 | } 164 | 165 | void SetTmpPartIdx(int32_t* tmp_part_ind){ 166 | tmp_part_ind_ = tmp_part_ind; 167 | } 168 | 169 | void SetTmpPartOff(int32_t* tmp_part_off){ 170 | tmp_part_off_ = tmp_part_off; 171 | } 172 | 173 | void SetOpId(int32_t op_id){ 174 | op_id_ = op_id; 175 | } 176 | 177 | void SetCurrentPipe(int32_t current_pipe){ 178 | current_pipe_ = current_pipe; 179 | } 180 | 181 | void SetCurrentMode(int32_t mode){ 182 | mode_ = mode; 183 | } 184 | 185 | void SetIter(int32_t iter){ 186 | iter_ = iter; 187 | } 188 | 189 | void Finalize() { 190 | cudaFree(cache_search_buffer_); 191 | cudaFree(accessed_map_); 192 | cudaFree(position_map_); 193 | cudaFree(agg_src_ids_); 194 | cudaFree(agg_dst_ids_); 195 | } 196 | 197 | private: 198 | int32_t iter_; 199 | int32_t mode_; 200 | int32_t op_id_; 201 | int32_t* cache_search_buffer_; 202 | uint32_t* accessed_map_; 203 | int32_t* position_map_; 204 | int32_t* agg_src_ids_; 205 | int32_t* agg_dst_ids_; 206 | int32_t* tmp_src_of_; 207 | int32_t* tmp_dst_of_; 208 | void* temp_storage_; 209 | int32_t* tmp_part_ind_; 210 | int32_t* tmp_part_off_; 211 | 212 | int32_t pipeline_depth_; 213 | int32_t current_pipe_; 214 | std::vector float_features_; 215 | std::vector labels_; 216 | std::vector node_counter_; 217 | std::vector edge_counter_; 218 | std::vector sampled_ids_; 219 | std::vector agg_src_off_; 220 | std::vector agg_dst_off_; 221 | }; 222 | 223 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/monitor.cuh: -------------------------------------------------------------------------------- 1 | #ifndef MONITOR_H 2 | #define MONITOR_H 3 | 4 | #include 5 | #ifdef _MSC_VER 6 | #include 7 | #include "windows/windriver.h" 8 | #else 9 | #include 10 | #include 11 | #endif 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | // #include 19 | #include 20 | #include "./pcm/src/pcm-pcie.h" 21 | // #include "cpucounters.h" 22 | 23 | #define PCM_DELAY_DEFAULT 1.0 // in seconds 24 | #define PCM_DELAY_MIN 0.015 // 15 milliseconds is practical on most modern CPUs 25 | 26 | using namespace std; 27 | 28 | bool events_printed = false; 29 | 30 | // #include "zerocp.h" 31 | 32 | IPlatform *IPlatform::getPlatform(PCM *m, bool csv, bool print_bandwidth, bool print_additional_info, uint32 delay) 33 | { 34 | switch (m->getCPUModel()) { 35 | case PCM::ICX: 36 | case PCM::SNOWRIDGE: 37 | return new WhitleyPlatform(m, csv, print_bandwidth, print_additional_info, delay); 38 | case PCM::SKX: 39 | return new PurleyPlatform(m, csv, print_bandwidth, print_additional_info, delay); 40 | case PCM::BDX_DE: 41 | case PCM::BDX: 42 | case PCM::KNL: 43 | case PCM::HASWELLX: 44 | return new GrantleyPlatform(m, csv, print_bandwidth, print_additional_info, delay); 45 | case PCM::IVYTOWN: 46 | case PCM::JAKETOWN: 47 | return new BromolowPlatform(m, csv, print_bandwidth, print_additional_info, delay); 48 | default: 49 | return NULL; 50 | } 51 | } 52 | 53 | class PCM_Monitor { 54 | public: 55 
| void Init(); 56 | void Start(); 57 | void Stop(); 58 | void Print(); 59 | std::vector GetCounter(); 60 | 61 | private: 62 | std::vector counter_; 63 | IPlatform* platform_; 64 | vector *eventGroups_; 65 | PCM * m_; 66 | }; 67 | 68 | // #endif 69 | 70 | void PCM_Monitor::Init(){ 71 | 72 | set_signal_handlers(); 73 | 74 | cerr << "\n"; 75 | cerr << " Intel(r) Performance Counter Monitor: PCIe Bandwidth Monitoring Utility \n"; 76 | cerr << " This utility measures PCIe bandwidth in real-time\n"; 77 | cerr << "\n"; 78 | 79 | // double delay = 1.0; 80 | bool csv = false; 81 | bool print_bandwidth = true; 82 | bool print_additional_info = false; 83 | // char * sysCmd = NULL; 84 | // char ** sysArgv = NULL; 85 | MainLoop mainLoop; 86 | 87 | m_= PCM::getInstance(); 88 | 89 | platform_ = IPlatform::getPlatform(m_, csv, print_bandwidth, 90 | print_additional_info, 1); // FIXME: do we support only integer delay? ; lgtm [cpp/fixme-comment] 91 | 92 | if (!platform_) 93 | { 94 | print_cpu_details(); 95 | cerr << "Jaketown, Ivytown, Haswell, Broadwell-DE Server CPU is required for this tool! Program aborted\n"; 96 | exit(EXIT_FAILURE); 97 | } 98 | 99 | eventGroups_= platform_->getEventGroups(); 100 | 101 | platform_->cleanup(); 102 | MySleepMs(uint(1000)); 103 | } 104 | 105 | void PCM_Monitor::Start(){ 106 | printf("Start Count PCIe\n"); 107 | for (auto &evGroup : *eventGroups_){ 108 | m_->programPCIeEventGroup(evGroup); 109 | platform_->getEventGroup(evGroup, 0); 110 | } 111 | for (auto &evGroup : *eventGroups_){ 112 | m_->programPCIeEventGroup(evGroup); 113 | } 114 | } 115 | 116 | void PCM_Monitor::Stop(){ 117 | for (auto &evGroup : *eventGroups_){ 118 | platform_->getEventGroup(evGroup, 1); 119 | } 120 | printf("Stop Count PCIe\n"); 121 | } 122 | 123 | void PCM_Monitor::Print(){ 124 | platform_->printHeader(); 125 | 126 | platform_->printEvents(); 127 | 128 | platform_->printAggregatedEvents(); 129 | } 130 | 131 | std::vector PCM_Monitor::GetCounter(){ 132 | 133 | return platform_->GetCounter(); 134 | } 135 | 136 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/operator.cu: -------------------------------------------------------------------------------- 1 | #include "operator.h" 2 | #include "operator_impl.cuh" 3 | #include "storage_management.cuh" 4 | #include "graph_storage.cuh" 5 | #include "feature_storage.cuh" 6 | #include "ipc_service.h" 7 | #include "cache.cuh" 8 | #include "memorypool.cuh" 9 | 10 | class BatchGenerateOP : public Operator { 11 | public: 12 | BatchGenerateOP(int op_id){ 13 | op_id_ = op_id; 14 | } 15 | void run(OpParams* params) override { 16 | FeatureStorage* feature = (FeatureStorage*)(params->feature); 17 | UnifiedCache* cache = (UnifiedCache*)(params->cache); 18 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 19 | IPCEnv* env = (IPCEnv*)(params->env); 20 | int32_t device_id = params->device_id; 21 | int32_t mode = memorypool->GetCurrentMode(); 22 | int32_t iter = memorypool->GetIter(); 23 | int32_t batch_size = env->GetCurrentBatchsize(device_id, mode); 24 | bool is_presc = params->is_presc; 25 | int32_t hop_num = params->hop_num; 26 | 27 | BatchGenerate(params->stream, feature, cache, memorypool, batch_size, iter, device_id, device_id, mode, is_presc, hop_num); 28 | cudaEventRecord(((params->event)), ((params->stream))); 29 | cudaCheckError(); 30 | } 31 | private: 32 | int op_id_; 33 | }; 34 | 35 | Operator* NewBatchGenerateOP(int op_id){ 36 | return new BatchGenerateOP(op_id); 37 | } 38 | 
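// The operators in this file form the per-batch sampling pipeline: BatchGenerateOP materializes the batch, RandomSampleOP expands one hop, CacheLookupOP serves features already resident in the unified cache, and SSDIOSubmitOP/SSDIOCompleteOP fetch the remaining misses from SSD. Each run() launches its work on params->stream and records params->event afterwards so that later stages can synchronize on it.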
39 | class RandomSampleOP : public Operator { 40 | public: 41 | RandomSampleOP(int op_id){ 42 | op_id_ = op_id; 43 | } 44 | void run(OpParams* params) override { 45 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 46 | GraphStorage* graph = (GraphStorage*)(params->graph); 47 | UnifiedCache* cache = (UnifiedCache*)(params->cache); 48 | bool is_presc = params->is_presc; 49 | int32_t count = params->neighbor_count; 50 | int32_t device_id = params->device_id; 51 | 52 | RandomSample(params->stream, graph, cache, memorypool, count, device_id, op_id_, is_presc); 53 | cudaEventRecord(((params->event)), ((params->stream))); 54 | cudaCheckError(); 55 | } 56 | private: 57 | int op_id_; 58 | }; 59 | 60 | Operator* NewRandomSampleOP(int op_id){ 61 | return new RandomSampleOP(op_id); 62 | } 63 | 64 | class CacheLookupOP : public Operator { 65 | public: 66 | CacheLookupOP(int op_id){ 67 | op_id_ = op_id; 68 | } 69 | void run(OpParams* params) override { 70 | UnifiedCache* cache = (UnifiedCache*)(params->cache); 71 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 72 | int32_t device_id = params->device_id; 73 | 74 | FeatureCacheLookup(params->stream, cache, memorypool, op_id_, device_id); 75 | cudaEventRecord(((params->event)), ((params->stream))); 76 | cudaCheckError(); 77 | } 78 | private: 79 | int op_id_; 80 | }; 81 | 82 | Operator* NewCacheLookupOP(int op_id){ 83 | return new CacheLookupOP(op_id); 84 | } 85 | 86 | class SSDIOSubmitOP : public Operator { 87 | public: 88 | SSDIOSubmitOP(int op_id) { 89 | op_id_ = op_id; 90 | } 91 | void run(OpParams* params) override { 92 | FeatureStorage* feature = (FeatureStorage*)(params->feature); 93 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 94 | int32_t device_id = params->device_id; 95 | 96 | IOSubmit(params->stream, feature, memorypool, op_id_, device_id); 97 | cudaEventRecord(((params->event)), ((params->stream))); 98 | cudaCheckError(); 99 | } 100 | private: 101 | int op_id_; 102 | }; 103 | 104 | Operator* NewSSDIOSubmitOP(int op_id){ 105 | return new SSDIOSubmitOP(op_id); 106 | } 107 | 108 | class SSDIOCompleteOP : public Operator { 109 | public: 110 | SSDIOCompleteOP(int op_id){ 111 | op_id_ = op_id; 112 | } 113 | void run(OpParams* params) override { 114 | FeatureStorage* feature = (FeatureStorage*)(params->feature); 115 | UnifiedCache* cache = (UnifiedCache*)(params->cache); 116 | MemoryPool* memorypool = (MemoryPool*)(params->memorypool); 117 | int mode = memorypool->GetCurrentMode(); 118 | int32_t device_id = params->device_id; 119 | bool is_presc = params->is_presc; 120 | IOComplete(params->stream, feature, cache, memorypool, device_id, mode, is_presc); 121 | cudaEventRecord(((params->event)), ((params->stream))); 122 | cudaCheckError(); 123 | } 124 | private: 125 | int op_id_; 126 | }; 127 | 128 | Operator* NewSSDIOCompleteOP(int op_id){ 129 | return new SSDIOCompleteOP(op_id); 130 | } 131 | -------------------------------------------------------------------------------- /sampling_server/src/engine/operator.h: -------------------------------------------------------------------------------- 1 | #ifndef OPERATOR_H 2 | #define OPERATOR_H 3 | 4 | struct OpParams { 5 | int device_id; 6 | cudaStream_t stream; 7 | cudaEvent_t event; 8 | void* memorypool; 9 | void* cache; 10 | void* graph; 11 | void* feature; 12 | void* env; 13 | int neighbor_count; 14 | bool is_presc; 15 | bool in_memory; 16 | int hop_num; 17 | }; 18 | 19 | class Operator { 20 | public: 21 | virtual void run(OpParams* params) = 0; 22 | }; 23 | 24 | 
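// Factory functions, one per pipeline stage; the implementations live in operator.cu.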
Operator* NewBatchGenerateOP(int op_id); 25 | Operator* NewRandomSampleOP(int op_id); 26 | Operator* NewCacheLookupOP(int op_id); 27 | Operator* NewSSDIOSubmitOP(int op_id); 28 | Operator* NewSSDIOCompleteOP(int op_id); 29 | 30 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/operator_impl.cuh: -------------------------------------------------------------------------------- 1 | #ifndef OPERATOR_IMPL_H 2 | #define OPERATOR_IMPL_H 3 | #include <cuda_runtime.h> 4 | #include <cstdint> 5 | 6 | #include "memorypool.cuh" 7 | #include "graph_storage.cuh" 8 | #include "feature_storage.cuh" 9 | #include "cache.cuh" 10 | 11 | extern "C" 12 | void BatchGenerate( 13 | cudaStream_t strm_hdl, 14 | FeatureStorage* feature, 15 | UnifiedCache* cache, 16 | MemoryPool* memorypool, 17 | int32_t batch_size, 18 | int32_t counter, 19 | int32_t part_id, 20 | int32_t dev_id, 21 | int32_t mode, 22 | bool is_presc, 23 | int32_t hop_num 24 | ); 25 | 26 | extern "C" 27 | void RandomSample( 28 | cudaStream_t strm_hdl, 29 | GraphStorage* graph, 30 | UnifiedCache* cache, 31 | MemoryPool* memorypool, 32 | int32_t count, 33 | int32_t dev_id, 34 | int32_t op_id, 35 | bool is_presc 36 | ); 37 | 38 | extern "C" 39 | void FeatureCacheLookup( 40 | cudaStream_t strm_hdl, 41 | UnifiedCache* cache, 42 | MemoryPool* memorypool, 43 | int32_t op_id, 44 | int32_t dev_id 45 | ); 46 | 47 | extern "C" 48 | void IOSubmit( 49 | cudaStream_t strm_hdl, 50 | FeatureStorage* feature, 51 | MemoryPool* memorypool, 52 | int32_t op_id, 53 | int32_t dev_id 54 | ); 55 | 56 | extern "C" 57 | void IOComplete( 58 | cudaStream_t strm_hdl, 59 | FeatureStorage* feature, 60 | UnifiedCache* cache, 61 | MemoryPool* memorypool, 62 | int32_t dev_id, 63 | int32_t mode, 64 | bool is_presc 65 | ); 66 | 67 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/server.h: -------------------------------------------------------------------------------- 1 | #ifndef SERVER_H 2 | #define SERVER_H 3 | #include <vector> 4 | 5 | struct RunnerParams { 6 | int device_id; 7 | std::vector<int> fanout; 8 | void* cache; 9 | void* graph; 10 | void* feature; 11 | void* env; 12 | int global_batch_id; 13 | bool in_memory; 14 | }; 15 | 16 | class Server { 17 | public: 18 | virtual void Initialize(int global_shard_count, std::vector<int> fanout) = 0; 19 | virtual void PreSc(int cache_agg_mode) = 0; 20 | virtual void Run() = 0; 21 | virtual void Finalize() = 0; 22 | }; 23 | Server* NewGPUServer(); 24 | 25 | class Runner { 26 | public: 27 | virtual void Initialize(RunnerParams* params) = 0; 28 | virtual void InitializeFeaturesBuffer(RunnerParams* params) = 0; 29 | virtual void RunPreSc(RunnerParams* params) = 0; 30 | virtual void RunOnce(RunnerParams* params) = 0; 31 | virtual void Finalize(RunnerParams* params) = 0; 32 | }; 33 | Runner* NewGPURunner(); 34 | 35 | #endif -------------------------------------------------------------------------------- /sampling_server/src/engine/server_imp.cuh: -------------------------------------------------------------------------------- 1 | 2 | extern "C" 3 | void* d_alloc_space(int64_t num_bytes) { 4 | void *ret; 5 | cudaMalloc(&ret, num_bytes); 6 | cudaCheckError(); 7 | return ret; 8 | } 9 | 10 | extern "C" 11 | void* d_alloc_space_managed(unsigned int num_bytes) { 12 | void *ret; 13 | cudaMallocManaged(&ret, num_bytes); 14 | cudaCheckError(); 15 | return ret; 16 | } 17 | 18 | extern "C" 19 | void d_copy_2_h(void* h_ptr, void* d_ptr, unsigned int num_bytes){ 20 | 
cudaMemcpy(h_ptr, d_ptr, num_bytes, cudaMemcpyDeviceToHost); 21 | cudaCheckError(); 22 | } 23 | 24 | 25 | extern "C" 26 | void SetGPUDevice(int32_t shard_id){ 27 | cudaSetDevice(shard_id); 28 | cudaCheckError(); 29 | } 30 | 31 | extern "C" 32 | int32_t GetGPUDevice(){ 33 | int32_t dev_id; 34 | cudaGetDevice(&dev_id); 35 | return dev_id; 36 | } 37 | 38 | extern "C" 39 | void d_free_space(void* d_ptr){ 40 | cudaFree(d_ptr); 41 | } 42 | 43 | 44 | extern "C" 45 | void* host_alloc_space(unsigned int num_bytes) { 46 | void* host_ptr; 47 | void* ret; 48 | cudaHostAlloc(&host_ptr, num_bytes, cudaHostAllocMapped); 49 | cudaHostGetDevicePointer(&ret, host_ptr, 0); 50 | cudaCheckError(); 51 | return ret; 52 | } -------------------------------------------------------------------------------- /sampling_server/src/include/buildinfo.h: -------------------------------------------------------------------------------- 1 | #ifndef BUILD_INFO_H 2 | #define BUILD_INFO_H 3 | #include 4 | #include 5 | 6 | struct BuildInfo{ 7 | //device `` 8 | std::vector shard_to_partition; 9 | std::vector shard_to_device; 10 | int32_t partition_count; 11 | //training set 12 | std::vector training_set_num; 13 | std::vector> training_set_ids; 14 | std::vector> training_labels; 15 | //validation set 16 | std::vector validation_set_num; 17 | std::vector> validation_set_ids; 18 | std::vector> validation_labels; 19 | //testing set 20 | std::vector testing_set_num; 21 | std::vector> testing_set_ids; 22 | std::vector> testing_labels; 23 | //features 24 | int32_t total_num_nodes; 25 | int32_t float_feature_len; 26 | float* host_float_feature;//allocated by cudaHostAlloc 27 | 28 | //bam params 29 | uint32_t cudaDevice; 30 | uint64_t cudaDeviceId; 31 | const char* blockDevicePath; 32 | const char* controllerPath; 33 | uint64_t controllerId; 34 | uint32_t adapter; 35 | uint32_t segmentId; 36 | uint32_t nvmNamespace; 37 | bool doubleBuffered; 38 | size_t numReqs; 39 | size_t numPages; 40 | size_t startBlock; 41 | bool stats; 42 | const char* output; 43 | size_t numThreads; 44 | uint32_t domain; 45 | uint32_t bus; 46 | uint32_t devfn; 47 | uint32_t n_ctrls; 48 | size_t blkSize; 49 | size_t queueDepth; 50 | size_t numQueues; 51 | size_t pageSize; 52 | uint64_t numElems; 53 | bool random; 54 | uint64_t ssdtype; 55 | 56 | //csr 57 | // std::vector> csr_node_index; 58 | // std::vector> csr_dst_node_ids; 59 | int64_t* csr_node_index; 60 | int32_t* csr_dst_node_ids; 61 | // std::vector partition_index; 62 | // std::vector partition_offset; 63 | int64_t cache_edge_num; 64 | int64_t total_edge_num; 65 | //train 66 | int32_t epoch; 67 | int32_t raw_batch_size; 68 | 69 | //iostack 70 | int32_t num_ssd; 71 | int32_t num_queues_per_ssd; 72 | }; 73 | 74 | #endif -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap.h: -------------------------------------------------------------------------------- 1 | #ifndef HASHMAPH 2 | #define HASHMAPH 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | // #include 9 | // #include 10 | #include 11 | // #include 12 | #include 13 | #include 14 | 15 | #endif -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB_RECURSE bght_src 2 | "*.h" 3 | "*.hpp" 4 | "*.cuh" 5 | "*.cu") 6 | set(SOURCE_LIST ${bght_src}) 7 | target_sources(bght INTERFACE ${bght_src}) 8 | 
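# Usage note (illustrative assumption, not in the original file): bght is a
# header-only INTERFACE target, so a consumer would typically just link it,
# e.g. (hypothetical target name):
#   target_link_libraries(sampling_server PRIVATE bght)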
target_include_directories(bght INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}") -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/benchmark_helpers.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | namespace benchmark { 29 | 30 | // min_key and max_key are exclusive 31 | template 32 | void generate_uniform_unique_keys( 33 | std::vector& keys, 34 | std::size_t num_keys, 35 | key_type min_key = std::numeric_limits::min() + 1, 36 | key_type max_key = std::numeric_limits::max() - 1, 37 | unsigned seed = 1, 38 | bool cache = false) { 39 | keys.resize(num_keys); 40 | std::string dataset_dir = "dataset"; 41 | std::string dataset_name = std::to_string(num_keys) + "_" + std::to_string(seed); 42 | std::string dataset_path = dataset_dir + "/" + dataset_name; 43 | if (cache) { 44 | if (std::filesystem::exists(dataset_dir)) { 45 | if (std::filesystem::exists(dataset_path)) { 46 | std::cout << "Reading cached keys.." << std::endl; 47 | std::ifstream dataset(dataset_path, std::ios::binary); 48 | dataset.read((char*)keys.data(), sizeof(key_type) * num_keys); 49 | dataset.close(); 50 | return; 51 | } 52 | } else { 53 | std::filesystem::create_directory(dataset_dir); 54 | } 55 | } 56 | std::random_device rd; 57 | std::mt19937 rng(seed); 58 | std::uniform_int_distribution uni(min_key, max_key); 59 | std::unordered_set unique_keys; 60 | while (unique_keys.size() < num_keys) { 61 | unique_keys.insert(uni(rng)); 62 | } 63 | std::copy(unique_keys.cbegin(), unique_keys.cend(), keys.begin()); 64 | 65 | if (cache) { 66 | std::cout << "Caching.." 
<< std::endl; 67 | std::ofstream dataset(dataset_path, std::ios::binary); 68 | dataset.write((char*)keys.data(), sizeof(key_type) * num_keys); 69 | dataset.close(); 70 | } 71 | } 72 | // 73 | // template 74 | // uint64_t validate(const std::vector& h_keys, 75 | // const std::vector& h_find_keys, 76 | // const thrust::device_vector& d_results, 77 | // const uint32_t& num_keys, 78 | // const value_type& sentinel_value, 79 | // function to_value, 80 | // float exist_ratio = 1.0f) { 81 | // uint64_t num_errors = 0; 82 | // uint64_t max_errors = 10; 83 | // using pair_type = bght::pair_type; 84 | // auto h_results = thrust::host_vector(d_results); 85 | // std::unordered_set cpu_ref_set; 86 | // if (exist_ratio != 1.0f) { 87 | // cpu_ref_set.insert(h_keys.begin(), h_keys.begin() + num_keys); 88 | // } 89 | // for (size_t i = 0; i < num_keys; i++) { 90 | // key_type query_key = h_find_keys[i]; 91 | // value_type query_result = h_results[i]; 92 | // value_type expected_result = to_value(query_key); 93 | // if (exist_ratio != 1.0f) { 94 | // auto expected_result_ptr = cpu_ref_set.find(query_key); 95 | // if (expected_result_ptr == cpu_ref_set.end()) { 96 | // expected_result = sentinel_value; 97 | // } 98 | // } 99 | // 100 | // if (query_result != expected_result) { 101 | // std::string message = std::string("query_key = ") + std::to_string(query_key) + 102 | // std::string(", expected: ") + 103 | // std::to_string(expected_result) + std::string(", found: ") + 104 | // std::to_string(query_result); 105 | // std::cout << message << std::endl; 106 | // num_errors++; 107 | // if (num_errors == max_errors) 108 | // break; 109 | // } 110 | // } 111 | // return num_errors; 112 | //} 113 | 114 | template 115 | void prep_experiment_find_with_exist_ratio(float exist_ratio, 116 | std::size_t num_keys, 117 | const std::vector& keys, 118 | std::vector& find_keys, 119 | key_type* d_find_keys) { 120 | // Choose the keys over which we will search based on the 121 | // exist_ratio. Recall that keys.size() == 2 * num_keys. 122 | assert(num_keys * 2 == keys.size()); 123 | unsigned int end_index = num_keys * (-exist_ratio + 2); 124 | unsigned int start_index = end_index - num_keys; 125 | 126 | static constexpr uint32_t EMPTY_VALUE = 0xFFFFFFFF; 127 | 128 | // Need to copy our range [start_index, end_index) from keys 129 | // into find_keys. 130 | std::fill(find_keys.begin(), find_keys.end(), EMPTY_VALUE); 131 | std::copy(keys.begin() + start_index, keys.begin() + end_index, find_keys.begin()); 132 | cuda_try(cudaMemcpy(d_find_keys, 133 | find_keys.data(), 134 | sizeof(key_type) * find_keys.size(), 135 | cudaMemcpyHostToDevice)); 136 | } 137 | 138 | template 139 | void prep_experiment_find_with_exist_ratio(float exist_ratio, 140 | std::size_t num_keys, 141 | const thrust::device_vector& keys, 142 | thrust::device_vector& find_keys) { 143 | // Choose the keys over which we will search based on the 144 | // exist_ratio. Recall that keys.size() == 2 * num_keys. 145 | assert(num_keys * 2 == keys.size()); 146 | unsigned int end_index = num_keys * (-exist_ratio + 2); 147 | unsigned int start_index = end_index - num_keys; 148 | 149 | static constexpr uint32_t EMPTY_VALUE = 0xFFFFFFFF; 150 | 151 | // Need to copy our range [start_index, end_index) from keys 152 | // into find_keys. 
153 | thrust::fill(thrust::device, find_keys.begin(), find_keys.end(), EMPTY_VALUE); 154 | thrust::copy(thrust::device, 155 | keys.begin() + start_index, 156 | keys.begin() + end_index, 157 | find_keys.begin()); 158 | } 159 | 160 | } // namespace benchmark 161 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/cht.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | namespace bght { 28 | 29 | /** 30 | * @brief CHT CHT (cuckoo hash table) is an associative static GPU hash table 31 | * that contains key-value pairs with unique keys. The hash table is an open addressing 32 | * hash table based on the cuckoo hashing probing scheme (bucket size of one and using 33 | * four hash functions). 34 | * 35 | * @tparam Key Type for the hash map key 36 | * @tparam T Type for the mapped value 37 | * @tparam Hash Unary function object class that defines the hash function. The function 38 | * must have an `initialize_hf` specialization to initialize the hash function using a 39 | * random number generator 40 | * @tparam KeyEqual Binary function object class that compares two keys 41 | * @tparam Allocator The allocator to use for allocating GPU device memory 42 | */ 43 | template , 46 | class KeyEqual = bght::equal_to, 47 | cuda::thread_scope Scope = cuda::thread_scope_device, 48 | class Allocator = bght::cuda_allocator> 49 | struct cht { 50 | using value_type = pair; 51 | using key_type = Key; 52 | using mapped_type = T; 53 | using atomic_pair_type = cuda::atomic; 54 | using allocator_type = Allocator; 55 | using hasher = Hash; 56 | using atomic_pair_allocator_type = 57 | typename std::allocator_traits::rebind_alloc; 58 | using pool_allocator_type = 59 | typename std::allocator_traits::rebind_alloc; 60 | using key_equal = KeyEqual; 61 | 62 | /** 63 | * @brief Constructs the hash table with the specified capacity and uses the specified 64 | * sentinel key and value to define a sentinel pair. 
65 | * 66 | * @param capacity The number of slots to use in the hash table 67 | * @param sentinel_key A reserved sentinel key that defines an empty key 68 | * @param sentinel_value A reserved sentinel value that defines an empty value 69 | * @param allocator The allocator to use for allocating GPU device memory 70 | */ 71 | cht(std::size_t capacity, 72 | Key sentinel_key, 73 | T sentinel_value, 74 | Allocator const& allocator = Allocator{}); 75 | 76 | /** 77 | * @brief A shallow-copy constructor 78 | */ 79 | cht(const cht& other); 80 | /** 81 | * @brief Move constructor is currently deleted 82 | */ 83 | cht(cht&&) = delete; 84 | /** 85 | * @brief The assignment operator is currently deleted 86 | */ 87 | cht& operator=(const cht&) = delete; 88 | /** 89 | * @brief The move assignment operator is currently deleted 90 | */ 91 | cht& operator=(cht&&) = delete; 92 | /** 93 | * @brief Destructor that destroys the hash map and deallocate memory if no copies exist 94 | */ 95 | ~cht(); 96 | /** 97 | * @brief Clears the hash map and resets all slots 98 | */ 99 | void clear(); 100 | 101 | /** 102 | * @brief Host-side API for inserting all pairs defined by the input argument iterators. 103 | * All keys in the range must be unique and must not exist in the hash table. 104 | * @tparam InputIt Device-side iterator that can be converted to `value_type`. 105 | * @param first An iterator defining the beginning of the input pairs to insert 106 | * @param last An iterator defining the end of the input pairs to insert 107 | * @param stream A CUDA stream where the insertion operation will take place 108 | * @return A boolean indicating success (true) or failure (false) of the insertion 109 | * operation. 110 | */ 111 | template 112 | bool insert(InputIt first, InputIt last, cudaStream_t stream = 0); 113 | 114 | /** 115 | * @brief Host-side API for finding all keys defined by the input argument iterators. 116 | * @tparam InputIt Device-side iterator that can be converted to `key_type` 117 | * @tparam OutputIt Device-side iterator that can be converted to `mapped_type` 118 | * @param first An iterator defining the beginning of the input keys to find 119 | * @param last An iterator defining the end of the input keys to find 120 | * @param output_begin An iterator defining the beginning of the output buffer to store 121 | * the results into. The size of the buffer must match the number of queries defined by 122 | * the input iterators. 123 | * @param stream A CUDA stream where the insertion operation will take place 124 | */ 125 | template 126 | void find(InputIt first, InputIt last, OutputIt output_begin, cudaStream_t stream = 0); 127 | 128 | /** 129 | * @brief Device-side cooperative insertion API that inserts a single pair into the hash 130 | * map. This function is called by a single thread. 131 | * @param pair A key-value pair to insert into the hash map. 132 | * @return A boolean indicating success (true) or failure (false) of the insertion 133 | * operation. 134 | */ 135 | __device__ bool insert(value_type const& pair); 136 | 137 | /** 138 | * @brief Device-side cooperative find API that finds a single pair into the hash 139 | * map. 140 | * @param key A key to find in the hash map. 
The key must be the same 141 | * for all threads in the cooperative group tile 142 | * @return The value of the key if it exists in the map or the `sentinel_value` if the 143 | * key does not exist in the hash map 144 | */ 145 | __device__ mapped_type find(key_type const& key); 146 | 147 | /** 148 | * @brief Host-side API to randomize the hash functions used for the probing scheme. 149 | * This can be used when the hash table construction fails. The hash table must be 150 | * cleared after a call to this function. 151 | * @tparam RNG A pseudo-random number generator 152 | * @param rng An instantiation of the pseudo-random number generator 153 | */ 154 | template 155 | void randomize_hash_functions(RNG& rng); 156 | 157 | private: 158 | __device__ void set_build_success(const bool& success) { *d_build_success_ = success; } 159 | 160 | template 161 | friend __global__ void detail::kernels::insert_kernel(InputIt, InputIt, HashMap); 162 | 163 | template 164 | friend __global__ void detail::kernels::find_kernel(InputIt, 165 | InputIt, 166 | OutputIt, 167 | HashMap); 168 | 169 | std::size_t capacity_; 170 | key_type sentinel_key_{}; 171 | mapped_type sentinel_value_{}; 172 | allocator_type allocator_; 173 | atomic_pair_allocator_type atomic_pairs_allocator_; 174 | pool_allocator_type pool_allocator_; 175 | 176 | atomic_pair_type* d_table_{}; 177 | std::shared_ptr table_; 178 | 179 | bool* d_build_success_; 180 | std::shared_ptr build_success_; 181 | 182 | uint32_t max_cuckoo_chains_; 183 | 184 | Hash hf0_; 185 | Hash hf1_; 186 | Hash hf2_; 187 | Hash hf3_; 188 | 189 | std::size_t num_buckets_; 190 | }; 191 | 192 | } // namespace bght 193 | 194 | #include 195 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/cmd.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | std::string str_tolower(const std::string_view s) { 26 | std::string output(s.length(), ' '); 27 | std::transform(s.begin(), s.end(), output.begin(), [](unsigned char c) { 28 | return std::tolower(c); 29 | }); 30 | return output; 31 | } 32 | 33 | // Finds an argument value 34 | // auto arguments = std::vector(argv, argv + argc); 35 | // Example: 36 | // auto k = get_arg_value(arguments, "-flag") 37 | // auto arguments = std::vector(argv, argv + argc); 38 | template 39 | std::optional get_arg_value(const std::vector& arguments, 40 | const char* flag) { 41 | uint32_t first_argument = 1; 42 | for (uint32_t i = first_argument; i < arguments.size(); i++) { 43 | std::string_view argument = std::string_view(arguments[i]); 44 | auto key_start = argument.find_first_not_of("-"); 45 | auto value_start = argument.find("="); 46 | 47 | bool failed = argument.length() == 0; // there is an argument 48 | failed |= key_start == std::string::npos; // it has a - 49 | failed |= value_start == std::string::npos; // it has an = 50 | failed |= key_start > 2; // - or -- at beginning 51 | failed |= (value_start - key_start) == 0; // there is a key 52 | failed |= (argument.length() - value_start) == 1; // = is not last 53 | 54 | if (failed) { 55 | std::cout << "Invalid argument: " << argument << " ignored.\n"; 56 | std::cout << "Use: -flag=value " << std::endl; 57 | std::terminate(); 58 | } 59 | 60 | std::string_view argument_name = argument.substr(key_start, value_start - key_start); 61 | value_start++; // ignore the = 62 | std::string_view argument_value = 63 | argument.substr(value_start, argument.length() - key_start); 64 | 65 | if (argument_name == std::string_view(flag)) { 66 | if constexpr (std::is_same::value) { 67 | return static_cast(std::strtof(argument_value.data(), nullptr)); 68 | } else if constexpr (std::is_same::value) { 69 | return static_cast(std::strtod(argument_value.data(), nullptr)); 70 | } else if constexpr (std::is_same::value) { 71 | return static_cast(std::strtol(argument_value.data(), nullptr, 10)); 72 | } else if constexpr (std::is_same::value) { 73 | return static_cast(std::strtoll(argument_value.data(), nullptr, 10)); 74 | } else if constexpr (std::is_same::value) { 75 | return static_cast(std::strtoul(argument_value.data(), nullptr, 10)); 76 | } else if constexpr (std::is_same::value) { 77 | return static_cast(std::strtoull(argument_value.data(), nullptr, 10)); 78 | } else if constexpr (std::is_same::value) { 79 | return std::string(argument_value); 80 | } else if constexpr (std::is_same::value) { 81 | return str_tolower(argument_value) == "true"; 82 | } else { 83 | std::cout << "Unknown type" << std::endl; 84 | std::terminate(); 85 | } 86 | } 87 | } 88 | return {}; 89 | } 90 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/allocator.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | namespace bght { 21 | template 22 | struct cuda_deleter { 23 | void operator()(T* p) { cuda_try(cudaFree(p)); } 24 | }; 25 | 26 | template 27 | struct cuda_allocator { 28 | typedef std::size_t size_type; 29 | typedef std::ptrdiff_t difference_type; 30 | 31 | typedef T value_type; 32 | typedef T* pointer; 33 | typedef const T* const_pointer; 34 | typedef T& reference; 35 | typedef const T& const_reference; 36 | 37 | template 38 | struct rebind { 39 | typedef cuda_allocator other; 40 | }; 41 | cuda_allocator() = default; 42 | template 43 | constexpr cuda_allocator(const cuda_allocator&) noexcept {} 44 | T* allocate(std::size_t n) { 45 | void* p = nullptr; 46 | cuda_try(cudaMalloc(&p, n * sizeof(T))); 47 | return static_cast(p); 48 | } 49 | void deallocate(T* p, std::size_t n) noexcept { cuda_try(cudaFree(p)); } 50 | }; 51 | } // namespace bght 52 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/benchmark_metrics.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | #ifdef COUNT_PROBES 20 | __device__ __managed__ uint32_t global_probes_count = 0; 21 | #define INCREMENT_PROBES_IN_TILE \ 22 | if (tile.thread_rank() == 0) \ 23 | atomicAdd(&global_probes_count, 1); 24 | #define INCREMENT_PROBES atomicAdd(&global_probes_count, 1); 25 | namespace bght { 26 | // uint32_t get_num_probes() { 27 | // cudaDeviceSynchronize(); 28 | // auto count = global_probes_count; 29 | // global_probes_count = 0; 30 | // cudaDeviceSynchronize(); 31 | // return count; 32 | // } 33 | } // namespace bght 34 | #else 35 | #define INCREMENT_PROBES_IN_TILE 36 | #define INCREMENT_PROBES 37 | namespace bght { 38 | // uint32_t get_num_probes() { 39 | // return 0; 40 | // } 41 | } // namespace bght 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/bucket.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | 21 | namespace bght { 22 | namespace detail { 23 | template 24 | struct bucket { 25 | bucket() = delete; 26 | DEVICE_QUALIFIER 27 | bucket(atomic_pair_type* ptr, const tile_type& tile) : ptr_(ptr), tile_(tile) {} 28 | 29 | DEVICE_QUALIFIER 30 | bucket(const bucket& other) : lane_pair_(other.lane_pair_), ptr_(other.ptr_) {} 31 | 32 | DEVICE_QUALIFIER 33 | void load(cuda::memory_order order = cuda::memory_order_seq_cst) { 34 | lane_pair_ = ptr_[tile_.thread_rank()].load(order); 35 | } 36 | DEVICE_QUALIFIER 37 | int compute_load(const pair_type& sentinel) { 38 | auto load_bitmap = tile_.ballot(lane_pair_.first != sentinel.first); 39 | int load = __popc(load_bitmap); 40 | return load; 41 | } 42 | // returns -1 if not found 43 | template 44 | DEVICE_QUALIFIER int find_key_location(const typename pair_type::first_type& key, 45 | const KeyEqual key_equal) { 46 | bool key_exist = key_equal(key, lane_pair_.first); 47 | auto key_exist_bmap = tile_.ballot(key_exist); 48 | int key_lane = __ffs(key_exist_bmap); 49 | return key_lane - 1; 50 | } 51 | DEVICE_QUALIFIER 52 | typename pair_type::second_type get_value_from_lane(int location) { 53 | return tile_.shfl(lane_pair_.second, location); 54 | } 55 | 56 | DEVICE_QUALIFIER 57 | bool weak_cas_at_location(const pair_type& pair, 58 | const int location, 59 | const pair_type& sentinel, 60 | cuda::memory_order success = cuda::memory_order_seq_cst, 61 | cuda::memory_order failure = cuda::memory_order_seq_cst) { 62 | pair_type expected = sentinel; 63 | pair_type desired = pair; 64 | bool cas_success = 65 | ptr_[location].compare_exchange_weak(expected, desired, success, failure); 66 | return cas_success; 67 | } 68 | 69 | DEVICE_QUALIFIER 70 | bool strong_cas_at_location(const pair_type& pair, 71 | const int location, 72 | const pair_type& sentinel, 73 | cuda::memory_order success = cuda::memory_order_seq_cst, 74 | cuda::memory_order failure = cuda::memory_order_seq_cst) { 75 | pair_type expected = sentinel; 76 | pair_type desired = pair; 77 | bool cas_success = 78 | ptr_[location].compare_exchange_strong(expected, desired, success, failure); 79 | return cas_success; 80 | } 81 | 82 | DEVICE_QUALIFIER 83 | pair_type exch_at_location(const pair_type& pair, 84 | const int location, 85 | cuda::memory_order order = cuda::memory_order_seq_cst) { 86 | auto old = ptr_[location].exchange(pair, order); 87 | return old; 88 | } 89 | 90 | private: 91 | pair_type lane_pair_; 92 | atomic_pair_type* ptr_; 93 | tile_type tile_; 94 | }; 95 | } // namespace detail 96 | } // namespace bght 97 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/cuda_helpers.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | namespace bght { 19 | #define _device_ __device__ __forceinline__ 20 | #define _host_device_ __device__ __host__ __forceinline__ 21 | #define _kernel_ __global__ 22 | #define DEVICE_QUALIFIER __device__ inline 23 | namespace detail { 24 | #define cuda_try(call) \ 25 | do { \ 26 | cudaError_t err = call; \ 27 | if (err != cudaSuccess) { \ 28 | printf("CUDA error at %s %d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ 29 | std::terminate(); \ 30 | } \ 31 | } while (0) 32 | 33 | _device_ void cuda_assert(bool expression_result, const char* message = nullptr) { 34 | if (!expression_result) { 35 | if (message && ((threadIdx.x & 0x1f) == 0)) { // only lane 0 of each warp prints 36 | printf("assert failed: %s", message); 37 | } 38 | //__trap(); 39 | asm("trap;"); 40 | } 41 | } 42 | } // namespace detail 43 | 44 | // void set_device(int device_id) { 45 | // int device_count; 46 | // cudaGetDeviceCount(&device_count); 47 | // cudaDeviceProp devProp; 48 | // if (device_id < device_count) { 49 | // cudaSetDevice(device_id); 50 | // cudaGetDeviceProperties(&devProp, device_id); 51 | // std::cout << "Device[" << device_id << "]: " << devProp.name << std::endl; 52 | // } else { 53 | // std::cout << "No capable CUDA device found." << std::endl; 54 | // std::terminate(); 55 | // } 56 | // } 57 | } // namespace bght 58 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/hash_functions.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | namespace bght { 19 | template 20 | struct universal_hash { 21 | using key_type = Key; 22 | using result_type = Key; 23 | __host__ __device__ constexpr universal_hash(uint32_t hash_x, uint32_t hash_y) 24 | : hash_x_(hash_x), hash_y_(hash_y) {} 25 | 26 | constexpr result_type __host__ __device__ operator()(const key_type& key) const { 27 | return (((hash_x_ ^ key) + hash_y_) % prime_divisor); 28 | } 29 | 30 | universal_hash(const universal_hash&) = default; 31 | universal_hash() = default; 32 | universal_hash(universal_hash&&) = default; 33 | universal_hash& operator=(universal_hash const&) = default; 34 | universal_hash& operator=(universal_hash&&) = default; 35 | ~universal_hash() = default; 36 | 37 | static constexpr uint32_t prime_divisor = 4294967291u; 38 | 39 | private: 40 | uint32_t hash_x_; 41 | uint32_t hash_y_; 42 | }; 43 | 44 | template 45 | Hash initialize_hf(RNG& rng) { 46 | uint32_t x = rng() % Hash::prime_divisor; 47 | if (x < 1u) { 48 | x = 1; 49 | } 50 | uint32_t y = rng() % Hash::prime_divisor; 51 | return Hash(x, y); 52 | } 53 | } // namespace bght 54 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/kernels.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | 22 | namespace bght { 23 | namespace detail { 24 | namespace kernels { 25 | template 26 | __global__ void tiled_insert_kernel(InputIt first, InputIt last, HashMap map) { 27 | // construct the tile 28 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 29 | auto block = cooperative_groups::this_thread_block(); 30 | auto tile = cooperative_groups::tiled_partition(block); 31 | 32 | auto count = last - first; 33 | if ((thread_id - tile.thread_rank()) >= count) { 34 | return; 35 | } 36 | 37 | bool do_op = false; 38 | typename HashMap::value_type insertion_pair{}; 39 | 40 | // load the input 41 | if (thread_id < count) { 42 | insertion_pair = first[thread_id]; 43 | do_op = true; 44 | } 45 | 46 | bool success = true; 47 | // Do the insertion 48 | auto work_queue = tile.ballot(do_op); 49 | while (work_queue) { 50 | auto cur_rank = __ffs(work_queue) - 1; 51 | auto cur_pair = tile.shfl(insertion_pair, cur_rank); 52 | bool insertion_success = map.insert(cur_pair, tile); 53 | 54 | if (tile.thread_rank() == cur_rank) { 55 | do_op = false; 56 | success = insertion_success; 57 | } 58 | work_queue = tile.ballot(do_op); 59 | } 60 | 61 | if (!tile.all(success)) { 62 | *map.d_build_success_ = false; 63 | } 64 | } 65 | 66 | template 67 | __global__ void tiled_find_kernel(InputIt first, 68 | InputIt last, 69 | OutputIt output_begin, 70 | HashMap map) { 71 | // construct the tile 72 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 73 | auto block = cooperative_groups::this_thread_block(); 74 | auto tile = cooperative_groups::tiled_partition(block); 75 | 76 | auto count = last - first; 77 | if ((thread_id - tile.thread_rank()) >= count) { 78 | return; 79 | } 80 | 81 | bool do_op = false; 82 | typename HashMap::key_type find_key; 83 | typename HashMap::mapped_type result; 84 | 85 | // load the input 86 | if (thread_id < count) { 87 | find_key = first[thread_id]; 88 | do_op = true; 89 | } 90 | 91 | // Do the insertion 92 | auto work_queue = tile.ballot(do_op); 93 | while (work_queue) { 94 | auto cur_rank = __ffs(work_queue) - 1; 95 | auto cur_key = tile.shfl(find_key, cur_rank); 96 | 97 | typename HashMap::mapped_type find_result = map.find(cur_key, tile); 98 | 99 | if (tile.thread_rank() == cur_rank) { 100 | result = find_result; 101 | do_op = false; 102 | } 103 | work_queue = tile.ballot(do_op); 104 | } 105 | 106 | if (thread_id < count) { 107 | output_begin[thread_id] = result; 108 | } 109 | } 110 | 111 | template 112 | __global__ void insert_kernel(InputIt first, InputIt last, HashMap map) { 113 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 114 | auto count = last - first; 115 | 116 | if (thread_id < count) { 117 | auto insertion_pair = first[thread_id]; 118 | bool success = map.insert(insertion_pair); 119 | if (!success) { 120 | map.set_build_success(false); 121 | } 122 | } 123 | } 124 | 125 | template 126 | __global__ void find_kernel(InputIt first, 127 | InputIt last, 128 | OutputIt output_begin, 129 | HashMap map) { 130 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 131 | auto count = last - first; 132 | 133 | if (thread_id < count) { 134 | auto find_key = first[thread_id]; 135 | auto result = map.find(find_key); 136 | output_begin[thread_id] = result; 137 | } 138 | } 139 | 140 | template 141 | __global__ void count_kernel(const InputT count_key, std::size_t* count, HashMap map) { 142 | auto thread_id = threadIdx.x + blockIdx.x * blockDim.x; 143 | typedef cub::BlockReduce BlockReduce; 144 | 
__shared__ typename BlockReduce::TempStorage temp_storage; 145 | 146 | std::size_t match = 0; 147 | if (thread_id < map.capacity_) { 148 | const auto key = map.d_table_[thread_id].load(cuda::memory_order_relaxed).first; 149 | match = (key == count_key); 150 | } 151 | std::size_t block_num_matches = BlockReduce(temp_storage).Sum(match); 152 | if (threadIdx.x == 0) { 153 | auto sum = atomicAdd((unsigned long long int*)count, 154 | (unsigned long long int)block_num_matches); 155 | } 156 | } 157 | 158 | } // namespace kernels 159 | } // namespace detail 160 | } // namespace bght 161 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/pair.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | 21 | namespace bght { 22 | template () != 0> 23 | struct alignas(detail::pair_alignment()) padded_pair { 24 | using first_type = T1; 25 | using second_type = T2; 26 | T1 first; 27 | T2 second; 28 | padded_pair() = default; 29 | ~padded_pair() = default; 30 | padded_pair(padded_pair const&) = default; 31 | padded_pair(padded_pair&&) = default; 32 | padded_pair& operator=(padded_pair const&) = default; 33 | padded_pair& operator=(padded_pair&&) = default; 34 | 35 | __host__ __device__ inline bool operator==(const padded_pair& rhs) { 36 | return (this->first == rhs.first) && (this->second == rhs.second); 37 | } 38 | __host__ __device__ inline bool operator!=(const padded_pair& rhs) { 39 | return !(*this == rhs); 40 | } 41 | 42 | __host__ __device__ constexpr padded_pair(T1 const& t, T2 const& u) 43 | : first{t}, second{u} {} 44 | }; 45 | 46 | template 47 | struct alignas(detail::pair_alignment()) padded_pair { 48 | using first_type = T1; 49 | using second_type = T2; 50 | T1 first; 51 | T2 second; 52 | 53 | padded_pair() = default; 54 | ~padded_pair() = default; 55 | padded_pair(padded_pair const&) = default; 56 | padded_pair(padded_pair&&) = default; 57 | padded_pair& operator=(padded_pair const&) = default; 58 | padded_pair& operator=(padded_pair&&) = default; 59 | 60 | __host__ __device__ inline bool operator==(const padded_pair& rhs) { 61 | return (this->first == rhs.first) && (this->second == rhs.second); 62 | } 63 | __host__ __device__ inline bool operator!=(const padded_pair& rhs) { 64 | return !(*this == rhs); 65 | } 66 | 67 | __host__ __device__ constexpr padded_pair(T1 const& t, T2 const& u) 68 | : first{t}, second{u} {} 69 | 70 | private: 71 | char padding[detail::padding_size()] = {0}; 72 | }; 73 | 74 | template 75 | using pair = padded_pair; 76 | 77 | template 78 | struct equal_to { 79 | constexpr bool operator()(const T& lhs, const T& rhs) const { return lhs == rhs; } 80 | }; 81 | } // namespace bght 82 | 
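// Sanity-check sketch (illustrative, not in the original sources): under the
// padding rules in pair_detail.hpp, a 4+4-byte pair stays 8 bytes with 8-byte
// alignment, while a 4+8-byte pair is padded out to a 16-byte aligned block,
// so a pair can always be moved with a single aligned load/store:
//
//   static_assert(sizeof(bght::pair<uint32_t, uint32_t>) == 8 &&
//                 alignof(bght::pair<uint32_t, uint32_t>) == 8);
//   static_assert(sizeof(bght::pair<uint32_t, uint64_t>) == 16 &&
//                 alignof(bght::pair<uint32_t, uint64_t>) == 16);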
-------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/pair_detail.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | #include 19 | namespace bght { 20 | namespace detail { 21 | template 22 | constexpr std::size_t next_alignment() { 23 | constexpr std::size_t n = sizeof(T); 24 | if (n <= 4) 25 | return 4; 26 | if (n <= 8) 27 | return 8; 28 | return 16; 29 | } 30 | constexpr std::size_t next_alignment(std::size_t n) { 31 | if (n <= 4) 32 | return 4; 33 | if (n <= 8) 34 | return 8; 35 | return 16; 36 | } 37 | 38 | template 39 | constexpr std::size_t pair_size() { 40 | return sizeof(T1) + sizeof(T2); 41 | } 42 | 43 | template 44 | constexpr std::size_t pair_alignment() { 45 | return next_alignment(pair_size()); 46 | } 47 | 48 | template 49 | constexpr std::size_t padding_size() { 50 | constexpr auto psz = pair_size(); 51 | constexpr auto apsz = next_alignment(pair_size()); 52 | if (psz > apsz) { 53 | constexpr auto nsz = (1ull + (psz / apsz)) * apsz; 54 | return nsz - psz; 55 | } 56 | return apsz - psz; 57 | } 58 | } // namespace detail 59 | } // namespace bght 60 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/ptx.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | namespace bght { 18 | namespace detail { 19 | namespace bits { 20 | // Bit Field Extract. 21 | __device__ __forceinline__ int bfe(uint32_t src, int num_bits) { 22 | unsigned mask; 23 | asm("bfe.u32 %0, %1, 0, %2;" : "=r"(mask) : "r"(src), "r"(num_bits)); 24 | return mask; 25 | } 26 | 27 | // Find most significant non - sign bit. 
28 | // bfind(0) = -1, bfind(1) = 0 29 | __device__ __forceinline__ int bfind(uint32_t src) { 30 | int msb; 31 | asm("bfind.u32 %0, %1;" : "=r"(msb) : "r"(src)); 32 | return msb; 33 | } 34 | __device__ __forceinline__ int bfind(uint64_t src) { 35 | int msb; 36 | asm("bfind.u64 %0, %1;" : "=r"(msb) : "l"(src)); 37 | return msb; 38 | } 39 | }; // namespace bits 40 | } // namespace detail 41 | } // namespace bght 42 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/detail/rng.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | namespace bght { 19 | namespace detail { 20 | struct mars_rng_32 { 21 | uint32_t y; 22 | __host__ __device__ constexpr mars_rng_32() : y(2463534242) {} 23 | constexpr uint32_t __host__ __device__ operator()() { 24 | y ^= (y << 13); // Marsaglia xorshift with shift triple (13, 17, 5) 25 | y ^= (y >> 17); 26 | return (y ^= (y << 5)); 27 | } 28 | }; 29 | } // namespace detail 30 | } // namespace bght 31 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/genzipf.hpp: -------------------------------------------------------------------------------- 1 | //= Author: Kenneth J. Christensen = 2 | //= University of South Florida = 3 | //= WWW: http://www.csee.usf.edu/~christen = 4 | //= Email: christen@csee.usf.edu = 5 | //=-------------------------------------------------------------------------= 6 | 7 | //========================================================================= 8 | //= Multiplicative LCG for generating uniform(0.0, 1.0) random numbers = 9 | //= - x_n = 7^5*x_(n-1)mod(2^31 - 1) = 10 | //= - With x seeded to 1 the 10000th x value should be 1043618065 = 11 | //= - From R. Jain, "The Art of Computer Systems Performance Analysis," = 12 | //= John Wiley & Sons, 1991. 
(Page 443, Figure 26.2) = 13 | //========================================================================= 14 | #pragma once 15 | #include 16 | #include 17 | 18 | double rand_val(int seed) { 19 | const long a = 16807; // Multiplier 20 | const long m = 2147483647; // Modulus 21 | const long q = 127773; // m div a 22 | const long r = 2836; // m mod a 23 | static long x; // Random int value 24 | long x_div_q; // x divided by q 25 | long x_mod_q; // x modulo q 26 | long x_new; // New x value 27 | 28 | // Set the seed if argument is non-zero and then return zero 29 | if (seed > 0) { 30 | x = seed; 31 | return (0.0); 32 | } 33 | 34 | // RNG using integer arithmetic 35 | x_div_q = x / q; 36 | x_mod_q = x % q; 37 | x_new = (a * x_mod_q) - (r * x_div_q); 38 | if (x_new > 0) 39 | x = x_new; 40 | else 41 | x = x_new + m; 42 | 43 | // Return a random value between 0.0 and 1.0 44 | return ((double)x / m); 45 | } 46 | 47 | uint32_t zipf(double alpha, uint32_t n) { 48 | static bool first = true; // Static first time flag 49 | static double c = 0; // Normalization constant 50 | static double* sum_probs; // Pre-calculated sum of probabilities 51 | double z; // Uniform random number (0 < z < 1) 52 | uint32_t zipf_value; // Computed exponential value to be returned 53 | uint32_t i; // Loop counter 54 | uint32_t low, high, mid; // Binary-search bounds 55 | 56 | // Compute normalization constant on first call only 57 | if (first == true) { 58 | for (i = 1; i <= n; i++) 59 | c = c + (1.0 / pow((double)i, alpha)); 60 | c = 1.0 / c; 61 | 62 | sum_probs = reinterpret_cast<double*>(std::malloc((n + 1) * sizeof(*sum_probs))); 63 | sum_probs[0] = 0; 64 | for (i = 1; i <= n; i++) { 65 | sum_probs[i] = sum_probs[i - 1] + c / pow((double)i, alpha); 66 | // std::cout << i << "," << sum_probs[i] << std::endl; 67 | } 68 | first = false; 69 | std::cout << "Computed probabilities" << std::endl; 70 | } 71 | 72 | // Pull a uniform random number (0 < z < 1) 73 | do { 74 | z = rand_val(0); 75 | } while ((z == 0) || (z == 1)); 76 | 77 | // Map z to the value 78 | low = 1; 79 | high = n; 80 | do { 81 | mid = floor((low + high) / 2); 82 | if (sum_probs[mid] >= z && sum_probs[mid - 1] < z) { 83 | zipf_value = mid; 84 | break; 85 | } else if (sum_probs[mid] >= z) { 86 | high = mid - 1; 87 | } else { 88 | low = mid + 1; 89 | } 90 | } while (low <= high); 91 | 92 | // Assert that zipf_value is between 1 and N 93 | assert((zipf_value >= 1) && (zipf_value <= n)); 94 | 95 | return zipf_value; 96 | } -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/gpu_timer.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | 20 | struct gpu_timer { 21 | gpu_timer(cudaStream_t stream = 0) : start_{}, stop_{}, stream_(stream) { 22 | cudaEventCreate(&start_); 23 | cudaEventCreate(&stop_); 24 | } 25 | void start_timer() { cudaEventRecord(start_, stream_); } 26 | void stop_timer() { cudaEventRecord(stop_, stream_); } 27 | float get_elapsed_ms() { 28 | compute_ms(); 29 | return elapsed_time_; 30 | } 31 | 32 | float get_elapsed_s() { 33 | compute_ms(); 34 | return elapsed_time_ * 0.001f; 35 | } 36 | ~gpu_timer() { 37 | cudaEventDestroy(start_); 38 | cudaEventDestroy(stop_); 39 | }; 40 | 41 | private: 42 | void compute_ms() { 43 | cudaEventSynchronize(stop_); 44 | cudaEventElapsedTime(&elapsed_time_, start_, stop_); 45 | } 46 | cudaEvent_t start_, stop_; 47 | cudaStream_t stream_; 48 | float elapsed_time_ = 0.0f; 49 | }; -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/perf_report.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma once 18 | 19 | void std_cout_perf_report(float insertion_s, 20 | float find_s, 21 | std::size_t num_insertions, 22 | std::size_t num_finds) { 23 | std::cout << "inserted: " << num_insertions << " keys" << '\n'; 24 | std::cout << "finds: " << num_finds << " keys" << '\n'; 25 | 26 | double insertion_rate = double(num_insertions) * 1e-6 / insertion_s; 27 | double find_rate = double(num_finds) * 1e-6 / find_s; 28 | 29 | std::cout << "insert_rate: " << insertion_rate << " Mkey/s" << '\n'; 30 | std::cout << "find_rate: " << find_rate << " Mkey/s" << '\n'; 31 | } 32 | -------------------------------------------------------------------------------- /sampling_server/src/include/hashmap/rkg.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 The Regents of the University of California, Davis 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | namespace rkg { 24 | template <typename key_type, typename value_type> 25 | value_type generate_value(key_type in) { 26 | return in + 1; 27 | } 28 | 29 | template <typename key_type, typename value_type, typename size_type> 30 | void generate_uniform_unique_pairs(std::vector<key_type>& keys, 31 | std::vector<value_type>& values, 32 | size_type num_keys, 33 | bool cache = false, 34 | key_type min_key = 0) { 35 | keys.resize(num_keys); 36 | values.resize(num_keys); 37 | unsigned seed = 1; 38 | // bool cache = true; 39 | std::string dataset_dir = "dataset"; 40 | std::string dataset_name = std::to_string(num_keys) + "_" + std::to_string(seed); 41 | std::string dataset_path = dataset_dir + "/" + dataset_name; 42 | if (cache) { 43 | if (std::experimental::filesystem::exists(dataset_dir)) { 44 | if (std::experimental::filesystem::exists(dataset_path)) { 45 | std::cout << "Reading cached keys.." << std::endl; 46 | std::ifstream dataset(dataset_path, std::ios::binary); 47 | dataset.read((char*)keys.data(), sizeof(key_type) * num_keys); 48 | dataset.read((char*)values.data(), sizeof(value_type) * num_keys); 49 | dataset.close(); 50 | return; 51 | } 52 | } else { 53 | std::experimental::filesystem::create_directory(dataset_dir); 54 | } 55 | } 56 | std::random_device rd; 57 | std::mt19937 rng(seed); 58 | auto max_key = std::numeric_limits<key_type>::max() - 1; 59 | std::uniform_int_distribution<key_type> uni(min_key, max_key); 60 | std::unordered_set<key_type> unique_keys; 61 | while (unique_keys.size() < num_keys) { 62 | unique_keys.insert(uni(rng)); 63 | // unique_keys.insert(unique_keys.size() + 1); 64 | } 65 | std::copy(unique_keys.cbegin(), unique_keys.cend(), keys.begin()); 66 | std::shuffle(keys.begin(), keys.end(), rng); 67 | 68 | #ifdef _WIN32 69 | // OpenMP + windows don't allow unsigned loops 70 | for (uint32_t i = 0; i < unique_keys.size(); i++) { 71 | values[i] = generate_value<key_type, value_type>(keys[i]); 72 | } 73 | #else 74 | 75 | for (uint32_t i = 0; i < unique_keys.size(); i++) { 76 | values[i] = generate_value<key_type, value_type>(keys[i]); 77 | } 78 | #endif 79 | 80 | if (cache) { 81 | std::cout << "Caching.." 
<< std::endl; 82 | std::ofstream dataset(dataset_path, std::ios::binary); 83 | dataset.write((char*)keys.data(), sizeof(key_type) * num_keys); 84 | dataset.write((char*)values.data(), sizeof(value_type) * num_keys); 85 | dataset.close(); 86 | } 87 | } 88 | 89 | template 90 | void generate_uniform_unique_keys(std::vector& keys, size_type num_keys) { 91 | keys.resize(num_keys); 92 | unsigned seed = 1; 93 | std::random_device rd; 94 | std::mt19937 rng(seed); 95 | auto max_key = std::numeric_limits::max() - 1; 96 | std::uniform_int_distribution uni(0, max_key); 97 | std::unordered_set unique_keys; 98 | while (unique_keys.size() < num_keys) { 99 | unique_keys.insert(uni(rng)); 100 | } 101 | std::copy(unique_keys.cbegin(), unique_keys.cend(), keys.begin()); 102 | std::shuffle(keys.begin(), keys.end(), rng); 103 | } 104 | } // namespace rkg 105 | -------------------------------------------------------------------------------- /sampling_server/src/include/system_config.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /*-----------------------------iostack parameters---------------------------*/ 4 | #include 5 | #define REG_SIZE 0x4000 // BAR 0 mapped size 6 | #define REG_CC 0x14 // addr: controller configuration 7 | #define REG_CC_EN 0x1 // mask: enable controller 8 | #define REG_CSTS 0x1c // addr: controller status 9 | #define REG_CSTS_RDY 0x1 // mask: controller ready 10 | #define REG_AQA 0x24 // addr: admin queue attributes 11 | #define REG_ASQ 0x28 // addr: admin submission queue base addr 12 | #define REG_ACQ 0x30 // addr: admin completion queue base addr 13 | #define REG_SQTDBL 0x1000 // addr: submission queue 0 tail doorbell 14 | #define REG_CQHDBL 0x1004 // addr: completion queue 0 sq_tail doorbell 15 | #define DBL_STRIDE 8 16 | #define PHASE_MASK 0x10000 // mask: phase tag 17 | #define HOST_PGSZ 0x1000 18 | #define DEVICE_PGSZ 0x10000 19 | #define CID_MASK 0xffff // mask: command id 20 | #define SC_MASK 0xff // mask: status code 21 | #define BROADCAST_NSID 0 // broadcast namespace id 22 | #define OPCODE_SET_FEATURES 0x09 23 | #define OPCODE_CREATE_IO_CQ 0x05 24 | #define OPCODE_CREATE_IO_SQ 0x01 25 | #define OPCODE_READ 0x02 26 | #define OPCODE_WRITE 0x01 27 | #define FID_NUM_QUEUES 0x07 28 | #define LB_SIZE 0x200 29 | #define RW_RETRY_MASK 0x80000000 30 | #define SQ_ITEM_SIZE 64 31 | #define WARP_SIZE 32 32 | #define SQ_HEAD_MASK 0xffff 33 | 34 | #define MAX_IO_SIZE 4096 35 | #define LBS 512 36 | #define MAX_ITEMS (MAX_IO_SIZE / LBS) 37 | #define NUM_THREADS_PER_BLOCK 512 38 | #define ADMIN_QUEUE_DEPTH 64 39 | #define QUEUE_DEPTH 4096 40 | #define QUEUE_IOBUF_SIZE (MAX_IO_SIZE * QUEUE_DEPTH) 41 | #define NUM_PRP_ENTRIES (MAX_IO_SIZE / HOST_PGSZ) 42 | #define PRP_SIZE (NUM_PRP_ENTRIES * sizeof(uint64_t)) 43 | #define NUM_LBS_PER_SSD 0x100000000 44 | #define MAX_SSDS_SUPPORTED 16 45 | 46 | 47 | #define INTERBATCH_CON 2 //inter-batch pipeline concurrency 48 | #define INTRABATCH_CON 3 //intra-batch pipeline concurrency 49 | 50 | #define MAX_DEVICE 8 51 | #define MEMORY_USAGE 7 52 | #define TRAINMODE 0 53 | #define VALIDMODE 1 54 | #define TESTMODE 2 55 | 56 | #define CACHEMISS_FLAG -2 57 | #define CACHECPU_FLAG -1 58 | 59 | #define CHECK(ans) gpuAssert((ans), __FILE__, __LINE__) 60 | 61 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) 62 | { 63 | if (code != cudaSuccess) 64 | { 65 | fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); 66 | if (abort) 67 | exit(1); 
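// Illustrative, hypothetical call site (not from this repo): wrapping any
// CUDA runtime call as CHECK(cudaMalloc(&io_buf, QUEUE_IOBUF_SIZE)) pinpoints
// the failing file/line; aborting here keeps later kernels off a broken context.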
68 | } 69 | } -------------------------------------------------------------------------------- /sampling_server/src/storage/feature_storage.cu: -------------------------------------------------------------------------------- 1 | #include "feature_storage.cuh" 2 | #include "feature_storage_impl.cuh" 3 | #include 4 | 5 | class CompleteFeatureStorage : public FeatureStorage{ 6 | public: 7 | CompleteFeatureStorage(){ 8 | } 9 | 10 | virtual ~CompleteFeatureStorage(){}; 11 | 12 | void Build(BuildInfo* info) override { 13 | iostack_ = new IOStack(info->num_ssd, info->num_queues_per_ssd, 1, 32); 14 | num_ssd_ = info->num_ssd; 15 | std::cout<<"IOStack built\n"; 16 | queue_ = new UserQueue(32, 1024, 4000000); 17 | std::cout<<"UserQueue built\n"; 18 | 19 | int32_t partition_count = info->partition_count; 20 | total_num_nodes_ = info->total_num_nodes; 21 | float_feature_len_ = info->float_feature_len; 22 | float* host_float_feature = info->host_float_feature; 23 | 24 | cudaSetDevice(0); 25 | cudaMalloc(&d_num_req_, sizeof(int32_t)); 26 | cudaMemset(d_num_req_, 0, sizeof(int32_t)); 27 | 28 | training_set_num_.resize(partition_count); 29 | training_set_ids_.resize(partition_count); 30 | training_labels_.resize(partition_count); 31 | 32 | validation_set_num_.resize(partition_count); 33 | validation_set_ids_.resize(partition_count); 34 | validation_labels_.resize(partition_count); 35 | 36 | testing_set_num_.resize(partition_count); 37 | testing_set_ids_.resize(partition_count); 38 | testing_labels_.resize(partition_count); 39 | 40 | partition_count_ = partition_count; 41 | 42 | for(int32_t i = 0; i < info->shard_to_partition.size(); i++){ 43 | int32_t part_id = info->shard_to_partition[i]; 44 | int32_t device_id = info->shard_to_device[i]; 45 | 46 | training_set_num_[part_id] = info->training_set_num[part_id]; 47 | 48 | validation_set_num_[part_id] = info->validation_set_num[part_id]; 49 | testing_set_num_[part_id] = info->testing_set_num[part_id]; 50 | 51 | cudaSetDevice(device_id); 52 | cudaCheckError(); 53 | 54 | 55 | int32_t* train_ids; 56 | cudaMalloc(&train_ids, training_set_num_[part_id] * sizeof(int32_t)); 57 | cudaMemcpy(train_ids, info->training_set_ids[part_id].data(), training_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 58 | training_set_ids_[part_id] = train_ids; 59 | cudaCheckError(); 60 | 61 | int32_t* valid_ids; 62 | cudaMalloc(&valid_ids, validation_set_num_[part_id] * sizeof(int32_t)); 63 | cudaMemcpy(valid_ids, info->validation_set_ids[part_id].data(), validation_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 64 | validation_set_ids_[part_id] = valid_ids; 65 | cudaCheckError(); 66 | 67 | int32_t* test_ids; 68 | cudaMalloc(&test_ids, testing_set_num_[part_id] * sizeof(int32_t)); 69 | cudaMemcpy(test_ids, info->testing_set_ids[part_id].data(), testing_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 70 | testing_set_ids_[part_id] = test_ids; 71 | cudaCheckError(); 72 | 73 | int32_t* train_labels; 74 | cudaMalloc(&train_labels, training_set_num_[part_id] * sizeof(int32_t)); 75 | cudaMemcpy(train_labels, info->training_labels[part_id].data(), training_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 76 | training_labels_[part_id] = train_labels; 77 | cudaCheckError(); 78 | 79 | int32_t* valid_labels; 80 | cudaMalloc(&valid_labels, validation_set_num_[part_id] * sizeof(int32_t)); 81 | cudaMemcpy(valid_labels, info->validation_labels[part_id].data(), validation_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 82 | 
validation_labels_[part_id] = valid_labels; 83 | cudaCheckError(); 84 | 85 | int32_t* test_labels; 86 | cudaMalloc(&test_labels, testing_set_num_[part_id] * sizeof(int32_t)); 87 | cudaMemcpy(test_labels, info->testing_labels[part_id].data(), testing_set_num_[part_id] * sizeof(int32_t), cudaMemcpyHostToDevice); 88 | testing_labels_[part_id] = test_labels; 89 | cudaCheckError(); 90 | 91 | } 92 | 93 | cudaMalloc(&d_req_count_, sizeof(unsigned long long)); 94 | cudaMemset(d_req_count_, 0, sizeof(unsigned long long)); 95 | cudaCheckError(); 96 | 97 | }; 98 | 99 | void Finalize() override { 100 | // cudaFreeHost(float_feature_); 101 | for(int32_t i = 0; i < partition_count_; i++){ 102 | cudaSetDevice(i); 103 | cudaFree(training_set_ids_[i]); 104 | cudaFree(validation_set_ids_[i]); 105 | cudaFree(testing_set_ids_[i]); 106 | cudaFree(training_labels_[i]); 107 | cudaFree(validation_labels_[i]); 108 | cudaFree(testing_labels_[i]); 109 | } 110 | } 111 | 112 | int32_t* GetTrainingSetIds(int32_t part_id) const override { 113 | return training_set_ids_[part_id]; 114 | } 115 | int32_t* GetValidationSetIds(int32_t part_id) const override { 116 | return validation_set_ids_[part_id]; 117 | } 118 | int32_t* GetTestingSetIds(int32_t part_id) const override { 119 | return testing_set_ids_[part_id]; 120 | } 121 | 122 | int32_t* GetTrainingLabels(int32_t part_id) const override { 123 | return training_labels_[part_id]; 124 | }; 125 | int32_t* GetValidationLabels(int32_t part_id) const override { 126 | return validation_labels_[part_id]; 127 | } 128 | int32_t* GetTestingLabels(int32_t part_id) const override { 129 | return testing_labels_[part_id]; 130 | } 131 | 132 | int32_t TrainingSetSize(int32_t part_id) const override { 133 | return training_set_num_[part_id]; 134 | } 135 | int32_t ValidationSetSize(int32_t part_id) const override { 136 | return validation_set_num_[part_id]; 137 | } 138 | int32_t TestingSetSize(int32_t part_id) const override { 139 | return testing_set_num_[part_id]; 140 | } 141 | 142 | int32_t TotalNodeNum() const override { 143 | return total_num_nodes_; 144 | } 145 | 146 | float* GetAllFloatFeature() const override { 147 | return float_feature_; 148 | } 149 | int32_t GetFloatFeatureLen() const override { 150 | return float_feature_len_; 151 | } 152 | 153 | void Print(BuildInfo* info) override { 154 | } 155 | 156 | void IOSubmit(int32_t* sampled_ids, int32_t* cache_index, 157 | int32_t* node_counter, float* dst_float_buffer, 158 | int32_t op_id, cudaStream_t strm_hdl) override { 159 | 160 | IOReq* req = queue_->dequeue(node_counter, op_id, cache_index, sampled_ids, d_num_req_, dst_float_buffer , float_feature_len_, num_ssd_, strm_hdl); 161 | cudaCheckError(); 162 | iostack_->io_submission(req, d_num_req_, strm_hdl); // use device pointer to store request number, avoid CPU-GPU synchronization 163 | cudaCheckError(); 164 | } 165 | 166 | void IOComplete(cudaStream_t strm_hdl) override { 167 | 168 | iostack_->io_completion(strm_hdl); 169 | 170 | } 171 | 172 | private: 173 | std::vector training_set_num_; 174 | std::vector validation_set_num_; 175 | std::vector testing_set_num_; 176 | 177 | std::vector training_set_ids_; 178 | std::vector validation_set_ids_; 179 | std::vector testing_set_ids_; 180 | 181 | std::vector training_labels_; 182 | std::vector validation_labels_; 183 | std::vector testing_labels_; 184 | 185 | int32_t partition_count_; 186 | int32_t total_num_nodes_; 187 | float* float_feature_; 188 | int32_t float_feature_len_; 189 | 190 | unsigned long long* d_req_count_; 191 | 192 | 
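// Number of SSDs backing the feature store; Build() copies it from BuildInfo
// so the request-building kernel can stripe feature reads across all SSDs.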
int32_t num_ssd_; 193 | 194 | IOStack* iostack_;//single GPU multi-SSD 195 | UserQueue* queue_; 196 | int32_t* d_num_req_; 197 | // IOReq* h_reqs_; 198 | // IOReq* d_reqs_; 199 | // float *app_buf_; 200 | friend FeatureStorage* NewCompleteFeatureStorage(); 201 | }; 202 | 203 | extern "C" 204 | FeatureStorage* NewCompleteFeatureStorage(){ 205 | CompleteFeatureStorage* ret = new CompleteFeatureStorage(); 206 | return ret; 207 | } -------------------------------------------------------------------------------- /sampling_server/src/storage/feature_storage.cuh: -------------------------------------------------------------------------------- 1 | #ifndef FEATURE_STORAGE_H_ 2 | #define FEATURE_STORAGE_H_ 3 | 4 | #include "buildinfo.h" 5 | 6 | class FeatureStorage { 7 | public: 8 | virtual ~FeatureStorage() = default; 9 | 10 | virtual void Build(BuildInfo* info) = 0; 11 | virtual void Finalize() = 0; 12 | 13 | virtual int32_t* GetTrainingSetIds(int32_t part_id) const = 0; 14 | virtual int32_t* GetValidationSetIds(int32_t part_id) const = 0; 15 | virtual int32_t* GetTestingSetIds(int32_t part_id) const = 0; 16 | 17 | virtual int32_t* GetTrainingLabels(int32_t part_id) const = 0; 18 | virtual int32_t* GetValidationLabels(int32_t part_id) const = 0; 19 | virtual int32_t* GetTestingLabels(int32_t part_id) const = 0; 20 | 21 | virtual int32_t TrainingSetSize(int32_t part_id) const = 0; 22 | virtual int32_t ValidationSetSize(int32_t part_id) const = 0; 23 | virtual int32_t TestingSetSize(int32_t part_id) const = 0; 24 | 25 | virtual int32_t TotalNodeNum() const = 0; 26 | virtual float* GetAllFloatFeature() const = 0; 27 | virtual int32_t GetFloatFeatureLen() const = 0; 28 | 29 | virtual void Print(BuildInfo* info) = 0; 30 | 31 | virtual void IOSubmit(int32_t* sampled_ids, int32_t* cache_index, 32 | int32_t* node_counter, float* dst_float_buffer, 33 | int32_t op_id, cudaStream_t strm_hdl) = 0; 34 | 35 | virtual void IOComplete(cudaStream_t strm_hdl) = 0; 36 | 37 | }; 38 | 39 | extern "C" 40 | FeatureStorage* NewCompleteFeatureStorage(); 41 | 42 | #endif -------------------------------------------------------------------------------- /sampling_server/src/storage/feature_storage_impl.cuh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef FEATURE_STORAGE_IMPL_H_ 3 | #define FEATURE_STORAGE_IMPL_H_ 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "iostack.cuh" 10 | #include "userqueue.cuh" 11 | 12 | // Macro for checking cuda errors following a cuda launch or api call 13 | #define cudaCheckError() \ 14 | { \ 15 | cudaError_t e = cudaGetLastError(); \ 16 | if (e != cudaSuccess) { \ 17 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 18 | cudaGetErrorString(e)); \ 19 | exit(EXIT_FAILURE); \ 20 | } \ 21 | } 22 | 23 | 24 | #endif -------------------------------------------------------------------------------- /sampling_server/src/storage/graph_storage.cu: -------------------------------------------------------------------------------- 1 | #include "graph_storage.cuh" 2 | #include "graph_storage_impl.cuh" 3 | 4 | /*in this version, partition id = shard id = device id*/ 5 | class CompleteGraphStorage : public GraphStorage { 6 | public: 7 | CompleteGraphStorage() { 8 | } 9 | 10 | virtual ~CompleteGraphStorage() { 11 | } 12 | 13 | void Build(BuildInfo* info) override { 14 | int32_t partition_count = info->partition_count; 15 | partition_count_ = partition_count; 16 | node_num_ = info->total_num_nodes; 17 | edge_num_ = info->total_edge_num; 18 | 
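// Node and edge counts come from the BuildInfo that StorageManagement
// populates while loading the dataset.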
cache_edge_num_ = info->cache_edge_num; 19 | 20 | // shard count == partition count now 21 | csr_node_index_.resize(partition_count_); 22 | csr_dst_node_ids_.resize(partition_count_); 23 | partition_index_.resize(partition_count_); 24 | partition_offset_.resize(partition_count_); 25 | 26 | d_global_count_.resize(partition_count); 27 | h_global_count_.resize(partition_count); 28 | h_cache_hit_.resize(partition_count); 29 | find_iter_.resize(partition_count); 30 | h_batch_size_.resize(partition_count); 31 | 32 | for(int32_t i = 0; i < partition_count; i++){ 33 | cudaSetDevice(i); 34 | cudaMalloc(&csr_node_index_[i], 4 * sizeof(int64_t*)); 35 | cudaMalloc(&csr_dst_node_ids_[i], 4 * sizeof(int32_t*)); 36 | cudaMalloc(&d_global_count_[i], 4); 37 | h_global_count_[i] = (int32_t*)malloc(4); 38 | h_cache_hit_[i] = 0; 39 | find_iter_[i] = 0; 40 | h_batch_size_[i] = 0; 41 | } 42 | 43 | src_size_.resize(partition_count); 44 | dst_size_.resize(partition_count); 45 | cudaCheckError(); 46 | 47 | cudaSetDevice(0); 48 | 49 | int64_t* pin_csr_node_index; 50 | int32_t* pin_csr_dst_node_ids; 51 | 52 | h_csr_node_index_ = info->csr_node_index; 53 | h_csr_dst_node_ids_ = info->csr_dst_node_ids; 54 | 55 | // int64_t* d_csr_node_index; 56 | // int32_t* d_csr_dst_node_ids; 57 | // cudaMalloc(&d_csr_node_index, (node_num_ + 1) * sizeof(int64_t)); 58 | // cudaMalloc(&d_csr_dst_node_ids, edge_num_ * sizeof(int32_t)); 59 | // cudaMemcpy(d_csr_node_index, h_csr_node_index_, (node_num_ + 1) * sizeof(int64_t), cudaMemcpyHostToDevice); 60 | // cudaMemcpy(d_csr_dst_node_ids, h_csr_dst_node_ids_, edge_num_ * sizeof(int32_t), cudaMemcpyHostToDevice); 61 | 62 | cudaHostGetDevicePointer(&pin_csr_node_index, h_csr_node_index_, 0); 63 | cudaHostGetDevicePointer(&pin_csr_dst_node_ids, h_csr_dst_node_ids_, 0); 64 | assign_memory<<<1,1>>>(csr_dst_node_ids_[0], pin_csr_dst_node_ids, csr_node_index_[0], pin_csr_node_index, 2); 65 | // assign_memory<<<1,1>>>(csr_dst_node_ids_[0], d_csr_dst_node_ids, csr_node_index_[0], d_csr_node_index, 2); 66 | 67 | cudaCheckError(); 68 | // assign_memory<<<1,1>>>(csr_dst_node_ids_[0], d_csr_dst_node_ids, csr_node_index_[0], d_csr_node_index, partition_count); 69 | // cudaCheckError(); 70 | csr_node_index_cpu_ = pin_csr_node_index; 71 | csr_dst_node_ids_cpu_ = pin_csr_dst_node_ids; 72 | 73 | } 74 | 75 | void HyrbidGraphCache(int32_t* QT, int32_t cpu_capacity, int32_t gpu_capacity){ 76 | int64_t* neighbor_count; 77 | cudaMalloc(&neighbor_count, (cpu_capacity + gpu_capacity) * sizeof(int64_t)); 78 | GetNeighborCount<<<80, 1024>>>(QT, 1, 0, (cpu_capacity + gpu_capacity), csr_node_index_cpu_, neighbor_count); 79 | int64_t* d_csr_node_index; 80 | cudaMalloc(&d_csr_node_index, (int64_t(gpu_capacity + 1)*sizeof(int64_t))); 81 | cudaMemset(d_csr_node_index, 0, (int64_t(gpu_capacity + 1)*sizeof(int64_t))); 82 | 83 | thrust::inclusive_scan(thrust::device, neighbor_count, neighbor_count + gpu_capacity, d_csr_node_index + 1); 84 | cudaCheckError(); 85 | int64_t* h_csr_node_index = (int64_t*)malloc((gpu_capacity + 1) * sizeof(int64_t)); 86 | cudaMemcpy(h_csr_node_index, d_csr_node_index, (gpu_capacity + 1) * sizeof(int64_t), cudaMemcpyDeviceToHost); 87 | 88 | int32_t* d_csr_dst_node_ids; 89 | cudaMalloc(&d_csr_dst_node_ids, int64_t(int64_t(h_csr_node_index[gpu_capacity]) * sizeof(int32_t))); 90 | 91 | TopoFillUp<<<80, 1024>>>(QT, 1, 0, gpu_capacity, csr_node_index_cpu_, csr_dst_node_ids_cpu_, d_csr_node_index, d_csr_dst_node_ids); 92 | cudaCheckError(); 93 | assign_memory<<<1,1>>>(csr_dst_node_ids_[0], 
d_csr_dst_node_ids, csr_node_index_[0], d_csr_node_index, 0); 94 | 95 | int64_t* p_csr_node_index; 96 | cudaHostAlloc(&p_csr_node_index, (int64_t(cpu_capacity + 1)*sizeof(int64_t)), cudaHostAllocMapped); 97 | int64_t* pin_csr_node_index; 98 | cudaHostGetDevicePointer(&pin_csr_node_index, p_csr_node_index, 0); 99 | cudaMemset(pin_csr_node_index, 0, (int64_t(cpu_capacity + 1)*sizeof(int64_t))); 100 | if(cpu_capacity > 0){ 101 | thrust::inclusive_scan(thrust::device, neighbor_count + gpu_capacity, neighbor_count + gpu_capacity + cpu_capacity, pin_csr_node_index + 1); 102 | cudaCheckError(); 103 | int32_t* p_csr_dst_node_ids; 104 | cudaHostAlloc(&p_csr_dst_node_ids, int64_t(int64_t(p_csr_node_index[cpu_capacity]) * sizeof(int32_t)), cudaHostAllocMapped); 105 | int32_t* pin_csr_dst_node_ids; 106 | cudaHostGetDevicePointer(&pin_csr_dst_node_ids, p_csr_dst_node_ids, 0); 107 | TopoFillUp<<<80, 1024>>>(QT+gpu_capacity, 1, 0, cpu_capacity, csr_node_index_cpu_, csr_dst_node_ids_cpu_, pin_csr_node_index, pin_csr_dst_node_ids); 108 | cudaCheckError(); 109 | assign_memory<<<1,1>>>(csr_dst_node_ids_[0], pin_csr_dst_node_ids, csr_node_index_[0], pin_csr_node_index, 1); 110 | } 111 | 112 | cudaCheckError(); 113 | cudaFree(neighbor_count); 114 | } 115 | 116 | void Finalize() override { 117 | cudaFreeHost(csr_node_index_cpu_); 118 | cudaFreeHost(csr_dst_node_ids_cpu_); 119 | } 120 | 121 | //CSR 122 | int32_t GetPartitionCount() const override { 123 | return partition_count_; 124 | } 125 | int64_t** GetCSRNodeIndex(int32_t dev_id) const override { 126 | return csr_node_index_[dev_id]; 127 | } 128 | int32_t** GetCSRNodeMatrix(int32_t dev_id) const override { 129 | return csr_dst_node_ids_[dev_id]; 130 | } 131 | 132 | int64_t* GetCSRNodeIndexCPU() const override { 133 | return csr_node_index_cpu_; 134 | } 135 | 136 | int32_t* GetCSRNodeMatrixCPU() const override { 137 | return csr_dst_node_ids_cpu_; 138 | } 139 | 140 | int64_t Src_Size(int32_t part_id) const override { 141 | return src_size_[part_id]; 142 | } 143 | int64_t Dst_Size(int32_t part_id) const override { 144 | return dst_size_[part_id]; 145 | } 146 | int32_t* PartitionIndex(int32_t dev_id) const override { 147 | return partition_index_[dev_id]; 148 | } 149 | int32_t* PartitionOffset(int32_t dev_id) const override { 150 | return partition_offset_[dev_id]; 151 | } 152 | 153 | private: 154 | std::vector src_size_; 155 | std::vector dst_size_; 156 | 157 | int32_t node_num_; 158 | int64_t edge_num_; 159 | int64_t cache_edge_num_; 160 | 161 | //CSR graph, every partition has a ptr copy 162 | int32_t partition_count_; 163 | std::vector csr_node_index_; 164 | std::vector csr_dst_node_ids_; 165 | int64_t* csr_node_index_cpu_; 166 | int32_t* csr_dst_node_ids_cpu_; 167 | 168 | int64_t* h_csr_node_index_; 169 | int32_t* h_csr_dst_node_ids_; 170 | 171 | std::vector partition_index_; 172 | std::vector partition_offset_; 173 | 174 | std::vector h_global_count_; 175 | std::vector d_global_count_; 176 | 177 | 178 | std::vector find_iter_; 179 | std::vector h_cache_hit_; 180 | std::vector h_batch_size_; 181 | }; 182 | 183 | extern "C" 184 | GraphStorage* NewCompleteGraphStorage(){ 185 | CompleteGraphStorage* ret = new CompleteGraphStorage(); 186 | return ret; 187 | } 188 | -------------------------------------------------------------------------------- /sampling_server/src/storage/graph_storage.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef GRAPH_STORAGE_H_ 3 | #define GRAPH_STORAGE_H_ 4 | 5 | 
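// GraphStorage abstracts a partitioned CSR graph whose adjacency lists may
// live in GPU memory, zero-copy pinned host memory, or both; the concrete
// CompleteGraphStorage (graph_storage.cu) publishes the per-location pointers
// to device code via the assign_memory kernel.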
#include "buildinfo.h" 6 | 7 | class GraphStorage { 8 | public: 9 | virtual ~GraphStorage() = default; 10 | //build 11 | virtual void Build(BuildInfo* info) = 0; 12 | virtual void HyrbidGraphCache(int32_t* QT, int32_t cpu_capacity, int32_t gpu_capacity) = 0; 13 | virtual void Finalize() = 0; 14 | //CSR 15 | virtual int32_t GetPartitionCount() const = 0; 16 | virtual int64_t** GetCSRNodeIndex(int32_t dev_id) const = 0; 17 | virtual int32_t** GetCSRNodeMatrix(int32_t dev_id) const = 0; 18 | virtual int64_t* GetCSRNodeIndexCPU() const = 0; 19 | virtual int32_t* GetCSRNodeMatrixCPU() const = 0; 20 | virtual int64_t Src_Size(int32_t part_id) const = 0; 21 | virtual int64_t Dst_Size(int32_t part_id) const = 0; 22 | virtual int32_t* PartitionIndex(int32_t dev_id) const = 0; 23 | virtual int32_t* PartitionOffset(int32_t dev_id) const = 0; 24 | }; 25 | extern "C" 26 | GraphStorage* NewCompleteGraphStorage(); 27 | 28 | #endif // GRAPH_STORAGE_H_ -------------------------------------------------------------------------------- /sampling_server/src/storage/graph_storage_impl.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef GRAPH_STORAGE_IMPL_H_ 3 | #define GRAPH_STORAGE_IMPL_H_ 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | // Macro for checking cuda errors following a cuda launch or api call 16 | #define cudaCheckError() \ 17 | { \ 18 | cudaError_t e = cudaGetLastError(); \ 19 | if (e != cudaSuccess) { \ 20 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 21 | cudaGetErrorString(e)); \ 22 | exit(EXIT_FAILURE); \ 23 | } \ 24 | } 25 | 26 | 27 | __global__ void assign_memory(int32_t** int32_pptr, int32_t* int32_ptr, int64_t** int64_pptr, int64_t* int64_ptr, int32_t device_id){ 28 | int32_pptr[device_id] = int32_ptr; 29 | int64_pptr[device_id] = int64_ptr; 30 | } 31 | 32 | 33 | __global__ void GetNeighborCount(int32_t* QT, int32_t Kg, int32_t Ki, int32_t capacity, int64_t* csr_node_index_cpu, int64_t* neighbor_count){ 34 | for(int32_t thread_idx = threadIdx.x + blockDim.x * blockIdx.x; thread_idx < capacity; thread_idx += gridDim.x * blockDim.x){ 35 | int32_t cache_id = QT[thread_idx * Kg + Ki]; 36 | int64_t count = csr_node_index_cpu[cache_id + 1] - csr_node_index_cpu[cache_id]; 37 | neighbor_count[thread_idx] = count; 38 | } 39 | } 40 | 41 | __global__ void TopoFillUp(int32_t* QT, int32_t Kg, int32_t Ki, int32_t capacity, 42 | int64_t* csr_node_index_cpu, int32_t* csr_dst_node_ids_cpu, 43 | int64_t* d_csr_node_index, int32_t* d_csr_dst_node_ids){ 44 | for(int32_t thread_idx = threadIdx.x + blockDim.x * blockIdx.x; thread_idx < capacity; thread_idx += gridDim.x * blockDim.x){ 45 | int32_t cache_id = QT[thread_idx * Kg + Ki]; 46 | int64_t count = csr_node_index_cpu[cache_id + 1] - csr_node_index_cpu[cache_id]; 47 | for(int i = 0; i < count; i++){ 48 | int32_t neighbor_id = csr_dst_node_ids_cpu[csr_node_index_cpu[cache_id] + i]; 49 | int64_t start_off = d_csr_node_index[thread_idx]; 50 | d_csr_dst_node_ids[start_off + i] = neighbor_id; 51 | } 52 | } 53 | } 54 | 55 | 56 | 57 | #endif -------------------------------------------------------------------------------- /sampling_server/src/storage/ioctl.h: -------------------------------------------------------------------------------- 1 | #ifndef __NVM_INTERNAL_LINUX_IOCTL_H__ 2 | #define __NVM_INTERNAL_LINUX_IOCTL_H__ 3 | #ifdef __linux__ 4 | 5 | #include 6 | #include 7 | 8 | #define NVM_IOCTL_TYPE 0x80 9 | 
10 | 11 | 12 | /* Memory map request */ 13 | struct nvm_ioctl_map 14 | { 15 | uint64_t vaddr_start; 16 | size_t n_pages; 17 | uint64_t* ioaddrs; 18 | }; 19 | 20 | 21 | 22 | /* Supported operations */ 23 | enum nvm_ioctl_type 24 | { 25 | NVM_MAP_HOST_MEMORY = _IOW(NVM_IOCTL_TYPE, 1, struct nvm_ioctl_map), 26 | #ifdef _CUDA 27 | NVM_MAP_DEVICE_MEMORY = _IOW(NVM_IOCTL_TYPE, 2, struct nvm_ioctl_map), 28 | #endif 29 | NVM_UNMAP_MEMORY = _IOW(NVM_IOCTL_TYPE, 3, uint64_t) 30 | }; 31 | 32 | 33 | #endif /* __linux__ */ 34 | #endif /* __NVM_INTERNAL_LINUX_IOCTL_H__ */ 35 | -------------------------------------------------------------------------------- /sampling_server/src/storage/map.h: -------------------------------------------------------------------------------- 1 | #ifndef __NVM_INTERNAL_LINUX_MAP_H__ 2 | #define __NVM_INTERNAL_LINUX_MAP_H__ 3 | #ifdef __linux__ 4 | 5 | #include "linux/ioctl.h" 6 | #include "dma.h" 7 | 8 | 9 | /* 10 | * What kind of memory are we mapping. 11 | */ 12 | enum mapping_type 13 | { 14 | MAP_TYPE_CUDA = 0x1, // CUDA device memory 15 | MAP_TYPE_HOST = 0x2, // Host memory (RAM) 16 | MAP_TYPE_API = 0x4 // Allocated by the API (RAM) 17 | }; 18 | 19 | 20 | 21 | /* 22 | * Mapping container 23 | */ 24 | struct ioctl_mapping 25 | { 26 | enum mapping_type type; // What kind of memory 27 | void* buffer; 28 | struct va_range range; // Memory range descriptor 29 | }; 30 | 31 | 32 | #endif /* __linux__ */ 33 | #endif /* __NVM_INTERNAL_LINUX_MAP_H__ */ 34 | -------------------------------------------------------------------------------- /sampling_server/src/storage/ssdqp.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "system_config.cuh" 5 | class SSDQueuePair 6 | { 7 | public: 8 | volatile uint32_t *sq; 9 | volatile uint32_t *cq; 10 | uint32_t sq_tail; 11 | uint32_t sq_tail_old; 12 | uint32_t cq_head; 13 | uint32_t cmd_id; // also number of commands submitted 14 | uint32_t namespace_id; 15 | uint32_t *sqtdbl, *cqhdbl; 16 | uint32_t *cmd_id_to_req_id; 17 | uint32_t *cmd_id_to_sq_pos; 18 | bool *sq_entry_busy; 19 | uint32_t queue_depth; 20 | uint32_t num_completed; 21 | 22 | __host__ __device__ SSDQueuePair() 23 | { 24 | } 25 | 26 | __host__ __device__ SSDQueuePair(volatile uint32_t *sq, volatile uint32_t *cq, uint32_t namespace_id, uint32_t *sqtdbl, uint32_t *cqhdbl, uint32_t queue_depth, uint32_t *cmd_id_to_req_id = nullptr, uint32_t *cmd_id_to_sq_pos = nullptr, bool *sq_entry_busy = nullptr) 27 | : sq(sq), cq(cq), sq_tail(0), cq_head(0), cmd_id(0), namespace_id(namespace_id), sqtdbl(sqtdbl), cqhdbl(cqhdbl), cmd_id_to_req_id(cmd_id_to_req_id), cmd_id_to_sq_pos(cmd_id_to_sq_pos), sq_entry_busy(sq_entry_busy), queue_depth(queue_depth), num_completed(0) 28 | { 29 | sq_tail_old = 0; 30 | } 31 | 32 | __host__ __device__ void submit(uint32_t &cid, uint32_t opcode, uint64_t prp1, uint64_t prp2, uint32_t dw10, uint32_t dw11, uint32_t dw12 = 0) 33 | { 34 | // printf("%lx %lx %x %x %x %x %x\n", prp1, prp2, dw10, dw11, dw12, opcode, cmd_id); 35 | fill_sq(cmd_id, sq_tail, opcode, prp1, prp2, dw10, dw11, dw12); 36 | sq_tail = (sq_tail + 1) % queue_depth; 37 | *sqtdbl = sq_tail; 38 | cid = cmd_id; 39 | cmd_id = (cmd_id + 1) & CID_MASK; 40 | } 41 | 42 | __host__ __device__ void fill_sq(uint32_t cid, uint32_t pos, uint32_t opcode, uint64_t prp1, uint64_t prp2, uint32_t dw10, uint32_t dw11, uint32_t dw12 = 0, uint32_t req_id = 0xffffffff) 43 | { 44 | // if (req_id == 1152) 45 | sq[pos * 16 + 0] = opcode 
| (cid << 16); 46 | sq[pos * 16 + 1] = namespace_id; 47 | sq[pos * 16 + 6] = prp1 & 0xffffffff; 48 | sq[pos * 16 + 7] = prp1 >> 32; 49 | sq[pos * 16 + 8] = prp2 & 0xffffffff; 50 | sq[pos * 16 + 9] = prp2 >> 32; 51 | sq[pos * 16 + 10] = dw10; 52 | sq[pos * 16 + 11] = dw11; 53 | sq[pos * 16 + 12] = dw12; 54 | // printf("%u %u %u %u %u %u %u %u\n", opcode | (cid << 16), namespace_id, prp1 & 0xffffffff, prp1 >> 32, prp2 & 0xffffffff, prp2 >> 32, dw10, dw11, dw12); 55 | // printf("%u, %u\n", namespace_id, opcode); 56 | if (cmd_id_to_req_id) 57 | cmd_id_to_req_id[cid % queue_depth] = req_id; 58 | if (cmd_id_to_sq_pos) 59 | cmd_id_to_sq_pos[cid % queue_depth] = pos; 60 | if (sq_entry_busy) 61 | sq_entry_busy[pos] = true; 62 | } 63 | 64 | __host__ __device__ void poll(uint32_t &code, uint32_t cid) 65 | { 66 | uint32_t current_phase = ((cmd_id - 1) / queue_depth) & 1; 67 | uint32_t status = cq[cq_head * 4 + 3]; 68 | while (((status & PHASE_MASK) >> 16) == current_phase) 69 | status = cq[cq_head * 4 + 3]; 70 | if ((status & CID_MASK) != cid) 71 | { 72 | printf("expected cid: %d, actual cid: %d\n", cid, status & CID_MASK); 73 | assert(0); 74 | } 75 | cq_head = (cq_head + 1) % queue_depth; 76 | *cqhdbl = cq_head; 77 | code = (status >> 17) & SC_MASK; 78 | num_completed++; 79 | } 80 | 81 | __device__ uint32_t poll_range(int expected_sq_head, bool should_break) 82 | { 83 | // printf("cmd_id: %d, size: %d, current_phase: %d\n", cmd_id, size, current_phase); 84 | int i; 85 | uint32_t last_sq_head = ~0U; 86 | int last_num_completed = num_completed; 87 | int thread_id = threadIdx.x + blockIdx.x * blockDim.x; 88 | for (i = cq_head; (num_completed & CID_MASK) != cmd_id; i = (i + 1) % queue_depth) 89 | { 90 | uint32_t current_phase = (num_completed / queue_depth) & 1; 91 | uint32_t status = cq[i * 4 + 3]; 92 | uint64_t start = clock64(); 93 | while (((status & PHASE_MASK) >> 16) == current_phase) 94 | { 95 | status = cq[i * 4 + 3]; 96 | if (clock64() - start > 1000000000) 97 | { 98 | printf("timeout sq_tail=%d, cq_head=%d, i=%d, num_completed=%d, cmd_id=%d\n", sq_tail, cq_head, i, num_completed, cmd_id); 99 | printf("last_sq_head: %d, expected_sq_head: %d\n", last_sq_head, expected_sq_head); 100 | // int thread_id = blockIdx.x * blockDim.x + threadIdx.x; 101 | // if (thread_id) 102 | // return 0; 103 | // for (int m = 0; m < queue_depth; m++) 104 | // { 105 | // printf("SQE %d\n", m); 106 | // for (int n = 0; n < 16; n++) 107 | // printf("DW%2d, %08x\n", n, sq[m * 16 + n]); 108 | // } 109 | // for (int m = 0; m < queue_depth; m++) 110 | // { 111 | // printf("CQE %d\n", m); 112 | // for (int n = 0; n < 4; n++) 113 | // printf("DW%2d, %08x\n", n, cq[m * 4 + n]); 114 | // } 115 | return 1; 116 | } 117 | } 118 | int cmd_id = status & CID_MASK; 119 | int sq_pos = cmd_id_to_sq_pos[cmd_id % queue_depth]; 120 | if ((status >> 17) & SC_MASK) 121 | { 122 | printf("cq[%d] status: 0x%x, cid: %d\n", i, (status >> 17) & SC_MASK, status & CID_MASK); 123 | int req_id = cmd_id_to_req_id[cmd_id % queue_depth]; 124 | printf("req_id: %d, sq_pos: %d\n", req_id, sq_pos); 125 | // for (int i = 0; i < 16; i++) 126 | // printf("%08x ", sq[sq_pos * 16 + i]); 127 | // printf("\n"); 128 | return (status >> 17) & SC_MASK; 129 | } 130 | last_sq_head = cq[i * 4 + 2] & SQ_HEAD_MASK; 131 | sq_entry_busy[sq_pos] = false; 132 | // printf("thread %d freed sq_pos %d\n", thread_id, sq_pos); 133 | num_completed++; 134 | if (should_break && ((cq[i * 4 + 2] & SQ_HEAD_MASK) - expected_sq_head + queue_depth) % queue_depth <= WARP_SIZE) 135 | { 136 | // 
printf("cq[%d] sq_head: %d, expected_sq_head: %d\n", i, cq[i * 4 + 2] & SQ_HEAD_MASK, expected_sq_head); 137 | i = (i + 1) % queue_depth; 138 | if (num_completed - last_num_completed > 64) 139 | printf("%d: %d completed\n", thread_id, num_completed - last_num_completed); 140 | break; 141 | } 142 | } 143 | if (i != cq_head) 144 | { 145 | cq_head = i; 146 | // printf("cq_head is %p, set cqhdbl to %d\n", cqhdbl, cq_head); 147 | *cqhdbl = cq_head; 148 | } 149 | return 0; 150 | } 151 | 152 | __host__ __device__ uint32_t poll_until_sq_entry_free(int expected_sq_pos) 153 | { 154 | int thread_id = blockIdx.x * blockDim.x + threadIdx.x; 155 | int last_num_completed = num_completed; 156 | // printf("thread %d want to free sq_pos: %d num_completed %d cmd_id %d\n", thread_id, expected_sq_pos, num_completed, cmd_id); 157 | int i; 158 | for (i = cq_head; (num_completed & CID_MASK) != cmd_id; i = (i + 1) % queue_depth) 159 | { 160 | uint32_t current_phase = (num_completed / queue_depth) & 1; 161 | uint32_t status = cq[i * 4 + 3]; 162 | while (((status & PHASE_MASK) >> 16) == current_phase) 163 | status = cq[i * 4 + 3]; 164 | int cmd_id = status & CID_MASK; 165 | int sq_pos = cmd_id_to_sq_pos[cmd_id % queue_depth]; 166 | if ((status >> 17) & SC_MASK) 167 | { 168 | printf("cq[%d] status: 0x%x, cid: %d\n", i, (status >> 17) & SC_MASK, status & CID_MASK); 169 | int req_id = cmd_id_to_req_id[cmd_id % queue_depth]; 170 | printf("req_id: %d, sq_pos: %d\n", req_id, sq_pos); 171 | // for (int i = 0; i < 16; i++) 172 | // printf("%08x ", sq[sq_pos * 16 + i]); 173 | // printf("\n"); 174 | return (status >> 17) & SC_MASK; 175 | } 176 | sq_entry_busy[sq_pos] = false; 177 | // printf("thread %d manually freed sq_pos %d\n", thread_id, sq_pos); 178 | num_completed++; 179 | if (sq_pos == expected_sq_pos) 180 | { 181 | cq_head = (i + 1) % queue_depth; 182 | // printf("cq_head is %p, set cqhdbl to %d\n", cqhdbl, cq_head); 183 | *cqhdbl = cq_head; 184 | if (num_completed - last_num_completed > 64) 185 | printf("%d: %d completed\n", thread_id, num_completed - last_num_completed); 186 | return 0; 187 | } 188 | } 189 | // printf("thread %d failed to free sq_pos %d\n", thread_id, expected_sq_pos); 190 | return 1; 191 | } 192 | }; 193 | -------------------------------------------------------------------------------- /sampling_server/src/storage/storage_management.cuh: -------------------------------------------------------------------------------- 1 | #include 2 | #include "graph_storage.cuh" 3 | #include "feature_storage.cuh" 4 | #include "cache.cuh" 5 | #include "ipc_service.h" 6 | 7 | class StorageManagement { 8 | public: 9 | 10 | void Initialze(int32_t shard_count); 11 | 12 | GraphStorage* GetGraph(); 13 | 14 | FeatureStorage* GetFeature(); 15 | 16 | UnifiedCache* GetCache(); 17 | 18 | IPCEnv* GetIPCEnv(); 19 | 20 | int32_t Shard_To_Device(int32_t part_id); 21 | 22 | int32_t Shard_To_Partition(int32_t part_id); 23 | 24 | int32_t Central_Device(); 25 | 26 | private: 27 | void EnableP2PAccess(); 28 | 29 | void ConfigPartition(BuildInfo* info, int32_t shard_count); 30 | 31 | void ReadMetaFIle(BuildInfo* info); 32 | 33 | void LoadGraph(BuildInfo* info); 34 | 35 | void LoadFeature(BuildInfo* info); 36 | 37 | int32_t central_device_; 38 | std::vector shard_to_device_; 39 | std::vector shard_to_partition_; 40 | int32_t partition_; 41 | 42 | int64_t cache_edge_num_; 43 | int64_t edge_num_; 44 | int32_t node_num_; 45 | 46 | int32_t training_set_num_; 47 | int32_t validation_set_num_; 48 | int32_t testing_set_num_; 49 | 50 | int32_t 
float_feature_len_; 51 | 52 | int64_t cache_memory_; 53 | 54 | std::string dataset_path_; 55 | int32_t raw_batch_size_; 56 | int32_t epoch_; 57 | int32_t num_ssd_; 58 | int32_t num_queues_per_ssd_; 59 | int64_t cpu_topo_size_; 60 | int64_t gpu_topo_size_; 61 | int64_t cpu_feat_size_; 62 | int64_t gpu_feat_size_; 63 | GraphStorage* graph_; 64 | FeatureStorage* feature_; 65 | UnifiedCache* cache_; 66 | IPCEnv* env_; 67 | }; 68 | 69 | 70 | -------------------------------------------------------------------------------- /sampling_server/src/storage/storage_management_impl.cuh: -------------------------------------------------------------------------------- 1 | #ifndef STORAGE_MANAGEMENT_IMPL_H_ 2 | #define STORAGE_MANAGEMENT_IMPL_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | 33 | 34 | // Macro for checking cuda errors following a cuda launch or api call 35 | #define cudaCheckError() \ 36 | { \ 37 | cudaError_t e = cudaGetLastError(); \ 38 | if (e != cudaSuccess) { \ 39 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 40 | cudaGetErrorString(e)); \ 41 | exit(EXIT_FAILURE); \ 42 | } \ 43 | } 44 | 45 | 46 | void mmap_trainingset_read(std::string &training_file, std::vector& training_set_ids){ 47 | int64_t t_idx = 0; 48 | int32_t fd = open(training_file.c_str(), O_RDONLY); 49 | if(fd == -1){ 50 | std::cout<<"cannout open file: "<& labels){ 143 | int64_t n_idx = 0; 144 | int32_t fd = open(labels_file.c_str(), O_RDONLY); 145 | if(fd == -1){ 146 | std::cout<<"cannout open file: "< 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #define my_ULLONG_MAX (~0ULL) 19 | // #define MAX_ITEMS 16//when setting to 32 will cause compilation error 20 | using namespace std; 21 | 22 | typedef uint64_t app_addr_t; 23 | 24 | // Macro for checking cuda errors following a cuda launch or api call 25 | #define cudaCheckError() \ 26 | { \ 27 | cudaError_t e = cudaGetLastError(); \ 28 | if (e != cudaSuccess) { \ 29 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 30 | cudaGetErrorString(e)); \ 31 | exit(EXIT_FAILURE); \ 32 | } \ 33 | } 34 | 35 | 36 | struct cache_id{ 37 | uint64_t buffer_id; 38 | uint64_t ssd_id; 39 | __forceinline__ 40 | __host__ __device__ 41 | cache_id(){ 42 | buffer_id=my_ULLONG_MAX; 43 | ssd_id=my_ULLONG_MAX; 44 | } 45 | __forceinline__ 46 | __host__ __device__ 47 | cache_id(uint64_t ssd,uint64_t buffer):ssd_id(ssd),buffer_id(buffer){} 48 | 49 | __host__ __device__ 50 | bool operator < (const cache_id& lhs)const{ 51 | return this->ssd_id>>(node_counter, op_id, d_ret_, p_miss_cnt_, input_ids, cache_index, float_feature_len, dst_float_buffer, num_ssd); 109 | cudaMemcpyAsync(p_miss_cnt,p_miss_cnt_,sizeof(int32_t),cudaMemcpyDeviceToDevice, stream); 110 | cudaMemsetAsync(p_miss_cnt_, 0, sizeof(int32_t), stream); 111 | // int32_t* h_miss_cnt = (int32_t*)malloc(sizeof(int32_t)); 112 | // cudaMemcpy(h_miss_cnt, p_miss_cnt, sizeof(int32_t), cudaMemcpyDeviceToHost); 113 | // std::cout<<"op "<>>(node_counter, op_id, d_ret_, p_miss_cnt_, input_ids, cache_index, float_feature_len, dst_float_buffer, num_ssd); 120 
| cudaMemcpyAsync(p_miss_cnt,p_miss_cnt_,sizeof(int32_t),cudaMemcpyDeviceToDevice, stream); 121 | cudaMemsetAsync(p_miss_cnt_, 0, sizeof(int32_t), stream); 122 | // int32_t* h_miss_cnt = (int32_t*)malloc(sizeof(int32_t)); 123 | // cudaMemcpy(h_miss_cnt, p_miss_cnt, sizeof(int32_t), cudaMemcpyDeviceToHost); 124 | // std::cout<<"op "< dist(0, 1000000000 - 1); 163 | int32_t input_id = dist(engine); 164 | // int32_t input_id = input_ids[thread_id]; 165 | int32_t cache_idx = cache_index[thread_id]; 166 | if(cache_idx < 0){ 167 | uint64_t offset = atomicAdd(p_miss_cnt, 1); 168 | d_ret[offset].start_lb = uint64_t(input_id % num_ssd) * NUM_LBS_PER_SSD + (input_id % (NUM_LBS_PER_SSD/(feature_block_size)))*feature_block_size;//uint64_t(input_id / num_ssd) * feature_block_size;//raid0 169 | d_ret[offset].num_items = feature_block_size; 170 | for(int j = 0; j < feature_block_size; j++){ 171 | d_ret[offset].dest_addr[j] = (app_addr_t)(dst_float_buffer + (int64_t(node_off) * float_feature_len) + (1ll * thread_id * float_feature_len + j * LBS) / sizeof(float)); 172 | } 173 | } 174 | } 175 | } -------------------------------------------------------------------------------- /training_backend/helper_multiprocess.cpp: -------------------------------------------------------------------------------- 1 | #include "helper_multiprocess.h" 2 | #include 3 | #include 4 | 5 | int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info) 6 | { 7 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 8 | info->size = sz; 9 | info->shmHandle = CreateFileMapping(INVALID_HANDLE_VALUE, 10 | NULL, 11 | PAGE_READWRITE, 12 | 0, 13 | (DWORD)sz, 14 | name); 15 | if (info->shmHandle == 0) { 16 | return GetLastError(); 17 | } 18 | 19 | info->addr = MapViewOfFile(info->shmHandle, FILE_MAP_ALL_ACCESS, 0, 0, sz); 20 | if (info->addr == NULL) { 21 | return GetLastError(); 22 | } 23 | 24 | return 0; 25 | #else 26 | int status = 0; 27 | 28 | info->size = sz; 29 | 30 | info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777); 31 | if (info->shmFd < 0) { 32 | return errno; 33 | } 34 | 35 | status = ftruncate(info->shmFd, sz); 36 | if (status != 0) { 37 | return status; 38 | } 39 | 40 | info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); 41 | if (info->addr == NULL) { 42 | return errno; 43 | } 44 | 45 | return 0; 46 | #endif 47 | } 48 | 49 | int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info) 50 | { 51 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 52 | info->size = sz; 53 | 54 | info->shmHandle = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, name); 55 | if (info->shmHandle == 0) { 56 | return GetLastError(); 57 | } 58 | 59 | info->addr = MapViewOfFile(info->shmHandle, FILE_MAP_ALL_ACCESS, 0, 0, sz); 60 | if (info->addr == NULL) { 61 | return GetLastError(); 62 | } 63 | 64 | return 0; 65 | #else 66 | info->size = sz; 67 | 68 | info->shmFd = shm_open(name, O_RDWR, 0777); 69 | if (info->shmFd < 0) { 70 | return errno; 71 | } 72 | 73 | info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); 74 | if (info->addr == NULL) { 75 | return errno; 76 | } 77 | 78 | return 0; 79 | #endif 80 | } 81 | 82 | void sharedMemoryClose(sharedMemoryInfo *info) 83 | { 84 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 85 | if (info->addr) { 86 | UnmapViewOfFile(info->addr); 87 | } 88 | if (info->shmHandle) { 89 | CloseHandle(info->shmHandle); 90 | } 91 | #else 92 | if (info->addr) { 93 | munmap(info->addr, 
info->size); 94 | } 95 | if (info->shmFd) { 96 | close(info->shmFd); 97 | } 98 | #endif 99 | } 100 | 101 | int spawnProcess(Process *process, const char *app, char * const *args) 102 | { 103 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 104 | STARTUPINFO si = {0}; 105 | BOOL status; 106 | size_t arglen = 0; 107 | size_t argIdx = 0; 108 | std::string arg_string; 109 | memset(process, 0, sizeof(*process)); 110 | 111 | while (*args) { 112 | arg_string.append(*args).append(1, ' '); 113 | args++; 114 | } 115 | 116 | status = CreateProcess(app, LPSTR(arg_string.c_str()), NULL, NULL, FALSE, 0, NULL, NULL, &si, process); 117 | 118 | return status ? 0 : GetLastError(); 119 | #else 120 | *process = fork(); 121 | if (*process == 0) { 122 | if (0 > execvp(app, args)) { 123 | return errno; 124 | } 125 | } 126 | else if (*process < 0) { 127 | return errno; 128 | } 129 | return 0; 130 | #endif 131 | } 132 | 133 | int waitProcess(Process *process) 134 | { 135 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 136 | DWORD exitCode; 137 | WaitForSingleObject(process->hProcess, INFINITE); 138 | GetExitCodeProcess(process->hProcess, &exitCode); 139 | CloseHandle(process->hProcess); 140 | CloseHandle(process->hThread); 141 | return (int)exitCode; 142 | #else 143 | int status = 0; 144 | do { 145 | if (0 > waitpid(*process, &status, 0)) { 146 | return errno; 147 | } 148 | } while (!WIFEXITED(status)); 149 | return WEXITSTATUS(status); 150 | #endif 151 | } 152 | -------------------------------------------------------------------------------- /training_backend/helper_multiprocess.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-2018 NVIDIA Corporation. All rights reserved. 3 | * 4 | * Please refer to the NVIDIA end user license agreement (EULA) associated 5 | * with this source code for terms and conditions that govern your use of 6 | * this software. Any use, reproduction, disclosure, or distribution of 7 | * this software and related documentation outside the terms of the EULA 8 | * is strictly prohibited. 
9 | * 10 | */ 11 | 12 | #ifndef HELPER_MULTIPROCESS_H 13 | #define HELPER_MULTIPROCESS_H 14 | 15 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 16 | #ifndef WIN32_LEAN_AND_MEAN 17 | #define WIN32_LEAN_AND_MEAN 18 | #endif 19 | #include 20 | #else 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #endif 27 | 28 | typedef struct sharedMemoryInfo_st { 29 | void *addr; 30 | size_t size; 31 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 32 | HANDLE shmHandle; 33 | #else 34 | int shmFd; 35 | #endif 36 | } sharedMemoryInfo; 37 | 38 | int sharedMemoryCreate(const char *name, size_t sz, sharedMemoryInfo *info); 39 | 40 | int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info); 41 | 42 | void sharedMemoryClose(sharedMemoryInfo *info); 43 | 44 | 45 | #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) 46 | typedef PROCESS_INFORMATION Process; 47 | #else 48 | typedef pid_t Process; 49 | #endif 50 | 51 | int spawnProcess(Process *process, const char *app, char * const *args); 52 | 53 | int waitProcess(Process *process); 54 | 55 | #endif // HELPER_MULTIPROCESS_H 56 | -------------------------------------------------------------------------------- /training_backend/hyperion_graphsage.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0" 4 | import sys 5 | import tempfile 6 | import argparse 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torch.multiprocessing as mp 12 | 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | import torch.nn.functional as Func 15 | 16 | import ipc_service 17 | import dgl 18 | from dgl.nn.pytorch import SAGEConv 19 | from dgl.heterograph import DGLBlock 20 | import time 21 | import numpy as np 22 | import torchmetrics 23 | torch.set_printoptions(threshold=np.inf) 24 | 25 | def setup(rank, world_size): 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | os.environ['MASTER_PORT'] = '12355' 28 | # initialize the process group 29 | if torch.cuda.is_available(): 30 | dist.init_process_group('nccl', rank=rank, world_size=world_size) 31 | else: 32 | dist.init_process_group('gloo', rank=rank, world_size=world_size) 33 | 34 | def cleanup(): 35 | dist.destroy_process_group() 36 | 37 | class SAGE(nn.Module): 38 | def __init__(self, 39 | in_feats, 40 | n_hidden, 41 | n_classes, 42 | n_layers, 43 | activation, 44 | dropout): 45 | super().__init__() 46 | self.n_layers = n_layers 47 | self.n_hidden = n_hidden 48 | self.n_classes = n_classes 49 | self.layers = nn.ModuleList() 50 | self.layers.append(SAGEConv(in_feats, n_hidden, 'mean')) 51 | for _ in range(1, n_layers - 1): 52 | self.layers.append(SAGEConv(n_hidden, n_hidden, 'mean')) 53 | self.layers.append(SAGEConv(n_hidden, n_classes, 'mean')) 54 | self.dropout = nn.Dropout(dropout) 55 | self.activation = activation 56 | 57 | def forward(self, blocks, x): 58 | h = x 59 | for l, (layer, block) in enumerate(zip(self.layers, blocks)): 60 | h = layer(block, h) 61 | if l != len(self.layers) - 1: 62 | h = self.activation(h) 63 | h = self.dropout(h) 64 | return h 65 | 66 | def create_dgl_block(src, dst, num_src_nodes, num_dst_nodes): 67 | gidx = dgl.heterograph_index.create_unitgraph_from_coo(2, num_src_nodes, num_dst_nodes, src, dst, 'coo', row_sorted=True) 68 | g = DGLBlock(gidx, (['_N'], ['_N']), ['_E']) 69 | 70 | return g 71 | 72 | 
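# create_dgl_block wraps the COO arrays shared by the sampling server into a
# DGL message-flow block, skipping DGL's usual graph-construction overhead.
# A minimal standalone sketch (toy tensors, not part of the IPC pipeline):
#
#   src = torch.tensor([0, 1, 2], device='cuda')  # edge source positions
#   dst = torch.tensor([0, 0, 1], device='cuda')  # edge destination positions
#   block = create_dgl_block(src, dst, num_src_nodes=3, num_dst_nodes=2)
#   layer = SAGEConv(16, 8, 'mean')
#   h = layer(block, torch.randn(3, 16, device='cuda'))  # -> (2, 8) dst feats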
def train_one_step(model, optimizer, loss_fcn, device, feat_len, iter, device_id): 73 | 74 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 75 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 76 | 77 | blocks = [] 78 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 79 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 80 | # print(features[:100]) 81 | # print(ids[:100]) 82 | batch_pred = model(blocks, features) 83 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 84 | loss = loss_fcn(batch_pred, long_labels) 85 | optimizer.zero_grad() 86 | loss.backward() 87 | optimizer.step() 88 | 89 | torch.cuda.synchronize() 90 | ipc_service.synchronize() 91 | return 0 92 | 93 | def valid_one_step(model, metric, device, feat_len): 94 | 95 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 96 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 97 | blocks = [] 98 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 99 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 100 | batch_pred = model(blocks, features) 101 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 102 | batch_pred = torch.softmax(batch_pred, dim=1).to(device) 103 | acc = metric(batch_pred, long_labels) 104 | ipc_service.synchronize() 105 | return acc 106 | 107 | def test_one_step(model, metric, device, feat_len): 108 | 109 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst = ipc_service.get_next(feat_len) 110 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num = ipc_service.get_block_size() 111 | blocks = [] 112 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 113 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 114 | batch_pred = model(blocks, features) 115 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device)*0 116 | batch_pred = torch.softmax(batch_pred, dim=1).to(device) 117 | acc = metric(batch_pred, long_labels) 118 | ipc_service.synchronize() 119 | return acc 120 | 121 | def worker_process(rank, world_size, args): 122 | print(f"Running GNN Training on CUDA {rank}.") 123 | device_id = rank 124 | setup(rank, world_size) 125 | cuda_device = torch.device("cuda:{}".format(device_id)) 126 | torch.cuda.set_device(cuda_device) 127 | ipc_service.initialize() 128 | train_steps, valid_steps, test_steps = ipc_service.get_steps() 129 | batch_size = (args.train_batch_size) 130 | hop1 = (args.nbrs_num)[0] 131 | hop2 = (args.nbrs_num)[1] 132 | 133 | feat_len = args.features_num 134 | 135 | model = SAGE(in_feats=args.features_num, 136 | n_hidden=args.hidden_dim, 137 | n_classes=args.class_num, 138 | n_layers=args.hops_num, 139 | activation=Func.relu, 140 | dropout=args.drop_rate).to(cuda_device) 141 | 142 | if dist.is_initialized(): 143 | model = DDP(model, device_ids=[device_id]) 144 | loss_fcn = nn.CrossEntropyLoss() 145 | loss_fcn = loss_fcn.to(device_id) 146 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 147 | model.train() 148 | 149 | epoch_num = args.epoch 150 | 151 | for epoch in 
range(epoch_num): 152 | forward = 0 153 | start = time.time() 154 | epoch_time = 0 155 | for iter in range(train_steps): 156 | train_loss = train_one_step(model, optimizer, loss_fcn, cuda_device, feat_len, iter, device_id) 157 | # if device_id == 0: 158 | # print('Iter {} Train Loss :{} '.format(iter, train_loss)) 159 | epoch_time += time.time() - start 160 | 161 | model.eval() 162 | metric = torchmetrics.Accuracy('multiclass', num_classes = args.class_num) 163 | metric = metric.to(device_id) 164 | model.metric = metric 165 | with torch.no_grad(): 166 | for iter in range(valid_steps): 167 | valid_one_step(model, metric, cuda_device, feat_len) 168 | acc_val = metric.compute() 169 | if device_id == 0: 170 | print("Epoch:{}, Cost:{} s, Val Acc: {}".format(epoch, epoch_time, acc_val)) 171 | 172 | 173 | model.eval() 174 | metric = torchmetrics.Accuracy('multiclass', num_classes = args.class_num) 175 | metric = metric.to(device_id) 176 | model.metric = metric 177 | with torch.no_grad(): 178 | for iter in range(test_steps): 179 | test_one_step(model, metric, cuda_device, feat_len) 180 | acc = metric.compute() 181 | if device_id == 0: 182 | print("Accuracy on test data: {}".format(acc)) 183 | metric.reset() 184 | 185 | ipc_service.finalize() 186 | cleanup() 187 | 188 | def run_distribute(dist_fn, world_size, args): 189 | mp.spawn(dist_fn, 190 | args=(world_size, args), 191 | nprocs=world_size, 192 | join=True) 193 | 194 | if __name__ == "__main__": 195 | cur_path = sys.path[0] 196 | argparser = argparse.ArgumentParser("Train GNN.") 197 | argparser.add_argument('--class_num', type=int, default=2) 198 | argparser.add_argument('--features_num', type=int, default=512) 199 | argparser.add_argument('--train_batch_size', type=int, default=8000) 200 | argparser.add_argument('--hidden_dim', type=int, default=256) 201 | argparser.add_argument('--hops_num', type=int, default=2) 202 | argparser.add_argument('--nbrs_num', type=list, default=[25, 10]) 203 | argparser.add_argument('--drop_rate', type=float, default=0.5) 204 | argparser.add_argument('--learning_rate', type=float, default=0.003) 205 | argparser.add_argument('--epoch', type=int, default=2) 206 | argparser.add_argument('--gpu_num', type=int, default=1) 207 | args = argparser.parse_args() 208 | 209 | world_size = args.gpu_num 210 | 211 | run_distribute(worker_process, world_size, args) 212 | -------------------------------------------------------------------------------- /training_backend/hyperion_graphsage3hop.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0" 4 | import sys 5 | import tempfile 6 | import argparse 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torch.multiprocessing as mp 12 | 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | import torch.nn.functional as Func 15 | 16 | import ipc_service 17 | import dgl 18 | from dgl.nn.pytorch import SAGEConv 19 | from dgl.heterograph import DGLBlock 20 | import time 21 | import numpy as np 22 | import torchmetrics 23 | torch.set_printoptions(threshold=np.inf) 24 | 25 | def setup(rank, world_size): 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | os.environ['MASTER_PORT'] = '12355' 28 | # initialize the process group 29 | if torch.cuda.is_available(): 30 | dist.init_process_group('nccl', rank=rank, world_size=world_size) 31 | else: 32 | dist.init_process_group('gloo', rank=rank, world_size=world_size) 
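# NOTE: MASTER_ADDR/MASTER_PORT above are hard-coded for single-node runs;
# point them at the rank-0 host to launch this script across multiple machines.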
33 | 34 | def cleanup(): 35 | dist.destroy_process_group() 36 | 37 | class SAGE(nn.Module): 38 | def __init__(self, 39 | in_feats, 40 | n_hidden, 41 | n_classes, 42 | n_layers, 43 | activation, 44 | dropout): 45 | super().__init__() 46 | self.n_layers = n_layers 47 | self.n_hidden = n_hidden 48 | self.n_classes = n_classes 49 | self.layers = nn.ModuleList() 50 | self.layers.append(SAGEConv(in_feats, n_hidden, 'mean')) 51 | for _ in range(1, n_layers - 1): 52 | self.layers.append(SAGEConv(n_hidden, n_hidden, 'mean')) 53 | self.layers.append(SAGEConv(n_hidden, n_classes, 'mean')) 54 | self.dropout = nn.Dropout(dropout) 55 | self.activation = activation 56 | 57 | def forward(self, blocks, x): 58 | h = x 59 | for l, (layer, block) in enumerate(zip(self.layers, blocks)): 60 | h = layer(block, h) 61 | if l != len(self.layers) - 1: 62 | h = self.activation(h) 63 | h = self.dropout(h) 64 | return h 65 | 66 | def create_dgl_block(src, dst, num_src_nodes, num_dst_nodes): 67 | gidx = dgl.heterograph_index.create_unitgraph_from_coo(2, num_src_nodes, num_dst_nodes, src, dst, 'coo', row_sorted=True) 68 | g = DGLBlock(gidx, (['_N'], ['_N']), ['_E']) 69 | 70 | return g 71 | 72 | def train_one_step(model, optimizer, loss_fcn, device, feat_len, iter, device_id): 73 | 74 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst, block3_agg_src, block3_agg_dst = ipc_service.get_next(feat_len) 75 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num, block3_src_num, block3_dst_num = ipc_service.get_block_size() 76 | 77 | blocks = [] 78 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 79 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 80 | blocks.append(create_dgl_block(block3_agg_src, block3_agg_dst, block3_src_num, block3_dst_num)) 81 | # print(features[:100]) 82 | # print(ids[:100]) 83 | batch_pred = model(blocks, features) 84 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 85 | loss = loss_fcn(batch_pred, long_labels) 86 | optimizer.zero_grad() 87 | loss.backward() 88 | optimizer.step() 89 | 90 | torch.cuda.synchronize() 91 | ipc_service.synchronize() 92 | return 0 93 | 94 | def valid_one_step(model, metric, device, feat_len): 95 | 96 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst, block3_agg_src, block3_agg_dst = ipc_service.get_next(feat_len) 97 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num, block3_src_num, block3_dst_num = ipc_service.get_block_size() 98 | 99 | blocks = [] 100 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 101 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 102 | blocks.append(create_dgl_block(block3_agg_src, block3_agg_dst, block3_src_num, block3_dst_num)) 103 | batch_pred = model(blocks, features) 104 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 105 | batch_pred = torch.softmax(batch_pred, dim=1).to(device) 106 | acc = metric(batch_pred, long_labels) 107 | ipc_service.synchronize() 108 | return acc 109 | 110 | def test_one_step(model, metric, device, feat_len): 111 | 112 | ids, features, labels, block1_agg_src, block1_agg_dst, block2_agg_src, block2_agg_dst, block3_agg_src, block3_agg_dst = ipc_service.get_next(feat_len) 113 | block1_src_num, block1_dst_num, block2_src_num, block2_dst_num, block3_src_num, block3_dst_num = 
ipc_service.get_block_size() 114 | 115 | blocks = [] 116 | blocks.append(create_dgl_block(block1_agg_src, block1_agg_dst, block1_src_num, block1_dst_num)) 117 | blocks.append(create_dgl_block(block2_agg_src, block2_agg_dst, block2_src_num, block2_dst_num)) 118 | blocks.append(create_dgl_block(block3_agg_src, block3_agg_dst, block3_src_num, block3_dst_num)) 119 | batch_pred = model(blocks, features) 120 | long_labels = torch.as_tensor(labels, dtype=torch.long, device=device) 121 | batch_pred = torch.softmax(batch_pred, dim=1).to(device) 122 | acc = metric(batch_pred, long_labels) 123 | ipc_service.synchronize() 124 | return acc 125 | 126 | def worker_process(rank, world_size, args): 127 | print(f"Running GNN Training on CUDA {rank}.") 128 | device_id = rank 129 | setup(rank, world_size) 130 | cuda_device = torch.device("cuda:{}".format(device_id)) 131 | torch.cuda.set_device(cuda_device) 132 | ipc_service.initialize() 133 | train_steps, valid_steps, test_steps = ipc_service.get_steps() 134 | feat_len = args.features_num 135 | 136 | model = SAGE(in_feats=args.features_num, 137 | n_hidden=args.hidden_dim, 138 | n_classes=args.class_num, 139 | n_layers=args.hops_num, 140 | activation=Func.relu, 141 | dropout=args.drop_rate).to(cuda_device) 142 | 143 | if dist.is_initialized(): 144 | model = DDP(model, device_ids=[device_id]) 145 | loss_fcn = nn.CrossEntropyLoss().to(device_id) 146 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 147 | 148 | epoch_num = args.epoch 149 | 150 | for epoch in range(epoch_num): 151 | model.train() # re-enable dropout: validation below switches the model to eval mode 152 | start = time.time() 153 | for step in range(train_steps): 154 | train_loss = train_one_step(model, optimizer, loss_fcn, cuda_device, feat_len, step, device_id) 155 | # if device_id == 0: 156 | # print('Iter {} Train Loss: {}'.format(step, train_loss)) 157 | epoch_time = time.time() - start # pure training time; validation is not counted 158 | 159 | model.eval() 160 | metric = torchmetrics.Accuracy('multiclass', num_classes=args.class_num) 161 | metric = metric.to(device_id) 162 | model.metric = metric 163 | with torch.no_grad(): 164 | for step in range(valid_steps): 165 | valid_one_step(model, metric, cuda_device, feat_len) 166 | acc_val = metric.compute() 167 | if device_id == 0: 168 | print("Epoch: {}, Cost: {} s, Val Acc: {}".format(epoch, epoch_time, acc_val)) 169 | 170 | model.eval() 171 | metric = torchmetrics.Accuracy('multiclass', num_classes=args.class_num) 172 | metric = metric.to(device_id) 173 | model.metric = metric 174 | with torch.no_grad(): 175 | for step in range(test_steps): 176 | test_one_step(model, metric, cuda_device, feat_len) 177 | acc = metric.compute() 178 | if device_id == 0: 179 | print("Accuracy on test data: {}".format(acc)) 180 | metric.reset() 181 | 182 | ipc_service.finalize() 183 | cleanup() 184 | 185 | def run_distribute(dist_fn, world_size, args): 186 | mp.spawn(dist_fn, 187 | args=(world_size, args), 188 | nprocs=world_size, 189 | join=True) 190 | 191 | if __name__ == "__main__": 192 | argparser = argparse.ArgumentParser("Train GNN.") 193 | argparser.add_argument('--class_num', type=int, default=2) 194 | argparser.add_argument('--features_num', type=int, default=1024) 195 | argparser.add_argument('--train_batch_size', type=int, default=1000) 196 | argparser.add_argument('--hidden_dim', type=int, default=256) 197 | argparser.add_argument('--hops_num', type=int, default=3) 198 | argparser.add_argument('--nbrs_num', type=lambda s: [int(x) for x in s.split(',')], default=[25, 10, 5]) 199 | argparser.add_argument('--drop_rate', type=float, default=0.5) 200 | argparser.add_argument('--learning_rate', type=float, default=0.003) 201 | argparser.add_argument('--epoch', type=int, default=2) 202 | argparser.add_argument('--gpu_num', type=int, default=1) 203 | args = argparser.parse_args() 204 | 205 | world_size = args.gpu_num 206 | 207 | run_distribute(worker_process, world_size, args)
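[Editor's usage sketch; flag values are illustrative, not from the source.] The trainer above consumes pre-sampled batches over IPC, so the sampling server is expected to be running first. With the comma-separated --nbrs_num parser above, a plausible two-GPU launch from training_backend/ is `python hyperion_graphsage3hop.py --features_num 128 --class_num 172 --hops_num 3 --nbrs_num 25,10,5 --gpu_num 2`.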
-------------------------------------------------------------------------------- /training_backend/ipc_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <cstdio> 2 | #include <cstdlib> 3 | #include "helper_multiprocess.h" 4 | #include <cuda.h> 5 | #include <cuda_runtime.h> 6 | 7 | #include <vector> 8 | 9 | #include <iostream> 10 | #include <string> 11 | #include "ipc_service.h" 12 | 13 | #include <semaphore.h> 14 | #include <torch/extension.h> 15 | 16 | #define MAX_DEVICE 8 17 | #define MEMORY_USAGE 7 18 | 19 | // Macro for checking cuda errors following a cuda launch or api call 20 | #define cudaCheckError() \ 21 | { \ 22 | cudaError_t e = cudaGetLastError(); \ 23 | if (e != cudaSuccess) { \ 24 | printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ 25 | cudaGetErrorString(e)); \ 26 | exit(EXIT_FAILURE); \ 27 | } \ 28 | } 29 | 30 | typedef struct shmStruct_st { 31 | int32_t steps[3]; 32 | cudaIpcMemHandle_t memHandle[MAX_DEVICE][INTERBATCH_CON][MEMORY_USAGE]; 33 | } shmStruct; 34 | 35 | class GPUIPCEnv : public IPCEnv { 36 | public: 37 | int Initialize() override { 38 | volatile shmStruct *shm = NULL; 39 | int central_device = -1; 40 | cudaGetDevice(&central_device); 41 | cudaCheckError(); 42 | sharedMemoryInfo info; 43 | const char shmName[] = "simpleIPCshm"; 44 | if (sharedMemoryCreate(shmName, sizeof(*shm), &info) != 0) { 45 | printf("Failed to create shared memory slab\n"); 46 | exit(EXIT_FAILURE); 47 | } 48 | 49 | shm = (volatile shmStruct *)info.addr; 50 | train_step_ = shm->steps[0]; 51 | valid_step_ = shm->steps[1]; 52 | test_step_ = shm->steps[2]; 53 | ids_.resize(INTERBATCH_CON); 54 | float_features_.resize(INTERBATCH_CON); 55 | labels_.resize(INTERBATCH_CON); 56 | agg_src_.resize(INTERBATCH_CON); 57 | agg_dst_.resize(INTERBATCH_CON); 58 | node_counter_.resize(INTERBATCH_CON); 59 | edge_counter_.resize(INTERBATCH_CON); 60 | 61 | for(int i = 0; i < INTERBATCH_CON; i++){ 62 | cudaIpcOpenMemHandle(&ids_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][0], cudaIpcMemLazyEnablePeerAccess); 63 | cudaIpcOpenMemHandle(&float_features_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][1], cudaIpcMemLazyEnablePeerAccess); 64 | cudaIpcOpenMemHandle(&labels_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][2], cudaIpcMemLazyEnablePeerAccess); 65 | cudaIpcOpenMemHandle(&agg_src_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][3], cudaIpcMemLazyEnablePeerAccess); 66 | cudaIpcOpenMemHandle(&agg_dst_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][4], cudaIpcMemLazyEnablePeerAccess); 67 | cudaIpcOpenMemHandle(&node_counter_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][5], cudaIpcMemLazyEnablePeerAccess); 68 | cudaIpcOpenMemHandle(&edge_counter_[i], *(cudaIpcMemHandle_t*)&shm->memHandle[central_device][i][6], cudaIpcMemLazyEnablePeerAccess); 69 | cudaCheckError(); 70 | } 71 | std::cout << "CUDA: " << central_device << std::endl; [lines 72-156 of this file are missing from the dump (the remainder of Initialize and the Get*/Wait/Post/Finalize overrides) and are not reconstructed here] std::vector<void *> ids_; 157 | std::vector<void *> float_features_; 158 | std::vector<void *> labels_; 159 | std::vector<void *> agg_src_; 160 | std::vector<void *> agg_dst_; 161 | 
std::vector<void *> node_counter_; 162 | std::vector<void *> edge_counter_; 163 | std::vector<sem_t *> semw_; 164 | std::vector<sem_t *> semr_; 165 | 166 | int32_t train_step_; 167 | int32_t valid_step_; 168 | int32_t test_step_; 169 | int current_pipe_; 170 | }; 171 | IPCEnv* NewIPCEnv(){ 172 | return new GPUIPCEnv(); 173 | } 174 | 175 | // Wrap the IPC-shared device buffers as torch Tensors; torch::from_blob copies nothing and takes no ownership. 176 | 177 | std::vector<torch::Tensor> cuda_get_next( 178 | int32_t* ids, 179 | float* float_features, 180 | int32_t* labels, 181 | int feature_dim, 182 | int32_t* agg_src, 183 | int32_t* agg_dst, 184 | int32_t* node_counter, 185 | int32_t* edge_counter, 186 | int32_t* h_node_counter, 187 | int32_t* h_edge_counter 188 | ){ 189 | int current_dev = -1; 190 | cudaGetDevice(&current_dev); 191 | auto device = "cuda:" + std::to_string(current_dev); 192 | cudaCheckError(); 193 | 194 | cudaMemcpy(h_node_counter, node_counter, 16 * sizeof(int32_t), cudaMemcpyDeviceToHost); 195 | cudaMemcpy(h_edge_counter, edge_counter, 16 * sizeof(int32_t), cudaMemcpyDeviceToHost); 196 | int hop_num = h_node_counter[INTRABATCH_CON * 3 - 1]; 197 | // std::cout << "hop num " << hop_num << std::endl; 198 | 199 | std::vector<torch::Tensor> ret; 200 | 201 | torch::Tensor ids_tensor = torch::from_blob( 202 | ids, 203 | {(long long)h_node_counter[INTRABATCH_CON * 3 + hop_num]}, 204 | torch::TensorOptions().dtype(torch::kI32).device(device)); 205 | 206 | ret.push_back(ids_tensor); 207 | 208 | torch::Tensor feature_tensor = torch::from_blob( 209 | float_features, 210 | {(long long)(h_node_counter[INTRABATCH_CON * 3 + hop_num]), (long long)(feature_dim)}, 211 | torch::TensorOptions().dtype(torch::kF32).device(device)); 212 | 213 | ret.push_back(feature_tensor); 214 | 215 | torch::Tensor labels_tensor = torch::from_blob( 216 | labels, 217 | {(long long)h_node_counter[INTRABATCH_CON * 3]}, 218 | torch::TensorOptions().dtype(torch::kI32).device(device)); 219 | 220 | ret.push_back(labels_tensor); 221 | 222 | for(int i = hop_num; i > 0; i--){ 223 | torch::Tensor agg_src_tensor = torch::from_blob( 224 | agg_src, 225 | {(long long)h_edge_counter[INTRABATCH_CON * 3 + i]}, 226 | torch::TensorOptions().dtype(torch::kI32).device(device)); 227 | torch::Tensor agg_dst_tensor = torch::from_blob( 228 | agg_dst, 229 | {(long long)h_edge_counter[INTRABATCH_CON * 3 + i]}, 230 | torch::TensorOptions().dtype(torch::kI32).device(device)); 231 | ret.push_back(agg_src_tensor); 232 | ret.push_back(agg_dst_tensor); 233 | } 234 | 235 | return ret; 236 | }
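The tensors built by cuda_get_next() are torch::from_blob views over the IPC-mapped device buffers, so they do not own their memory: the consumer must be done with a batch before synchronize() posts the buffer back to the sampling server. A minimal Python-side consumption sketch (editor's addition; the feature dimension 128 is illustrative):

import ipc_service  # the extension built by training_backend/setup.py

ipc_service.initialize()
train_steps, valid_steps, test_steps = ipc_service.get_steps()
ids, feats, labels, *coo = ipc_service.get_next(128)  # tensors alias shared GPU memory
# ... run forward/backward on this batch ...
ipc_service.synchronize()  # only now may the server overwrite the buffers
ipc_service.finalize()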
-------------------------------------------------------------------------------- /training_backend/ipc_service.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | #include <string> 4 | #include <iostream> 5 | #include <cstdio> 6 | #include <cstdlib> 7 | #include <cstdint> 8 | #include "ipc_service.h" 9 | 10 | 11 | IPCEnv* env; 12 | int32_t h_node_counter[16]; 13 | int32_t h_edge_counter[16]; 14 | 15 | void InitializeIPC(){ 16 | env = NewIPCEnv(); 17 | env->Initialize(); 18 | } 19 | 20 | void FinalizeIPC(){ 21 | env->Finalize(); 22 | } 23 | 24 | std::vector<torch::Tensor> cuda_get_next( 25 | int32_t* ids, 26 | float* float_features, 27 | int32_t* labels, 28 | int feature_dim, 29 | int32_t* agg_src, 30 | int32_t* agg_dst, 31 | int32_t* node_counter, 32 | int32_t* edge_counter, 33 | int32_t* h_node_counter, 34 | int32_t* h_edge_counter 35 | ); 36 | 37 | // C++ interface 38 | 39 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 40 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 41 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 42 | 43 | 44 | std::vector<torch::Tensor> get_next(int feature_dim) { 45 | env->Wait(); // block until the sampling server publishes the next batch 46 | int32_t* ids = env->GetIds(); 47 | float* float_features = env->GetFloatFeatures(); 48 | int32_t* labels = env->GetLabels(); 49 | int32_t* agg_src = env->GetAggSrc(); 50 | int32_t* agg_dst = env->GetAggDst(); 51 | int32_t* node_counter = env->GetNodeCounter(); 52 | int32_t* edge_counter = env->GetEdgeCounter(); 53 | auto result = cuda_get_next(ids, float_features, labels, 54 | feature_dim, 55 | agg_src, agg_dst, 56 | node_counter, edge_counter, 57 | h_node_counter, h_edge_counter); 58 | return result; 59 | } 60 | 61 | std::vector<int32_t> get_block_size() { 62 | std::vector<int32_t> ret; 63 | int hop_num = h_node_counter[INTRABATCH_CON * 3 - 1]; 64 | 65 | for(int i = hop_num; i > 0; i--){ 66 | ret.push_back(h_node_counter[INTRABATCH_CON * 3 + i]); // source node count of block i 67 | ret.push_back(h_node_counter[INTRABATCH_CON * 3 + i - 1]); // destination node count of block i 68 | } 69 | // e.g. with INTRABATCH_CON == 3 and two hops: 70 | // block1_src_node = h_node_counter[9]; block1_dst_node = h_node_counter[7]; 71 | // block2_src_node = h_node_counter[7]; block2_dst_node = h_node_counter[5]; 72 | return ret; 73 | } 74 | 75 | std::vector<int32_t> get_steps(){ 76 | std::vector<int32_t> ret; 77 | ret.push_back(env->GetTrainStep()); 78 | ret.push_back(env->GetValidStep()); 79 | ret.push_back(env->GetTestStep()); 80 | return ret; 81 | } 82 | 83 | void Synchronize(){ 84 | env->Post(); // hand the consumed buffer back to the sampling server 85 | } 86 | 87 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 88 | m.def("get_next", &get_next, "dataset get next (CUDA)"); 89 | m.def("get_block_size", &get_block_size, "get dgl block size (CUDA)"); 90 | m.def("get_steps", &get_steps, "get steps (CUDA)"); 91 | m.def("initialize", &InitializeIPC, "InitializeIPC (CUDA)"); 92 | m.def("finalize", &FinalizeIPC, "FinalizeIPC (CUDA)"); 93 | m.def("synchronize", &Synchronize, "synchronize (CUDA)"); 94 | } -------------------------------------------------------------------------------- /training_backend/ipc_service.h: -------------------------------------------------------------------------------- 1 | #include <cstdint> 2 | 3 | #define INTRABATCH_CON 3 4 | #define INTERBATCH_CON 2 5 | 6 | class IPCEnv { 7 | public: 8 | virtual int Initialize() = 0; 9 | virtual int32_t* GetIds() = 0; 10 | virtual float* GetFloatFeatures() = 0; 11 | virtual int32_t* GetLabels() = 0; 12 | virtual int32_t* GetAggSrc() = 0; 13 | virtual int32_t* GetAggDst() = 0; 14 | virtual int32_t* GetNodeCounter() = 0; 15 | virtual int32_t* GetEdgeCounter() = 0; 16 | virtual int32_t GetTrainStep() = 0; 17 | virtual int32_t GetValidStep() = 0; 18 | virtual int32_t GetTestStep() = 0; 19 | virtual void Finalize() = 0; 20 | virtual void Wait() = 0; 21 | virtual void Post() = 0; 22 | }; 23 | IPCEnv* NewIPCEnv(); -------------------------------------------------------------------------------- /training_backend/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | import os 4 | os.environ['CUDA_HOME'] = '/usr/local/cuda' 5 | setup( 6 | name='ipcservice', 7 | ext_modules=[ 8 | CUDAExtension('ipc_service', [ 9 | 'ipc_service.cpp', 10 | 'helper_multiprocess.cpp', 11 | 'ipc_cuda_kernel.cu', 12 | ]) 13 | ], 14 | cmdclass={ 15 | 'build_ext': BuildExtension 16 | })
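[Editor's note] setup.py above builds the ipc_service extension that every training script imports. A typical build, assuming CUDA lives at /usr/local/cuda as hardcoded above, is `cd training_backend && python setup.py install` (or `pip install .` in that directory), after which `import ipc_service` resolves against the installed extension.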
-------------------------------------------------------------------------------- /unload_ssd.py: -------------------------------------------------------------------------------- 1 | from os import popen, system 2 | import getpass 3 | 4 | ssds = [] 5 | 6 | if getpass.getuser() != "root": 7 | print("\033[41mERROR: Please run this script as root!\033[0m") 8 | quit(1) 9 | 10 | f = popen("lspci | grep Non-Volatile") 11 | lines = f.readlines() 12 | for line in lines: 13 | ssds.append(line.split(" ")[0]) 14 | 15 | for idx, ssd in enumerate(ssds): 16 | # adjust these index filters to choose which controllers to detach 17 | if idx != 0: 18 | continue 19 | if idx >= 12: 20 | continue 21 | tmp = ssd.replace(":", "\\:") 22 | # unbind the kernel nvme driver so the device can be claimed from user space 23 | system(f"""sh -c 'echo -n "0000:{ssd}" > /sys/bus/pci/devices/0000\\:{tmp}/driver/unbind'""") 24 | --------------------------------------------------------------------------------
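unload_ssd.py only detaches the kernel driver; nothing in this repository dump re-attaches it. A hedged companion sketch (editor's addition; the name reload_ssd.py and the choice to rebind every controller are assumptions, not from the source):

# reload_ssd.py (hypothetical; editor's sketch, not part of the repository)
# Re-binds NVMe controllers to the stock Linux nvme driver after the
# user-space I/O stack is done with them, mirroring unload_ssd.py.
from os import popen, system
import getpass

if getpass.getuser() != "root":
    print("\033[41mERROR: Please run this script as root!\033[0m")
    quit(1)

ssds = [line.split(" ")[0] for line in popen("lspci | grep Non-Volatile").readlines()]

for ssd in ssds:
    # Writing the full PCI address to the driver's bind file re-attaches the
    # kernel nvme driver that unload_ssd.py detached via .../driver/unbind.
    system(f"sh -c 'echo -n \"0000:{ssd}\" > /sys/bus/pci/drivers/nvme/bind'")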