├── fuseGNN ├── testbench │ ├── __init__.py │ ├── gcn_conv_tb.py │ └── gat_conv_tb.py ├── dataloader │ ├── reddit.py │ ├── __init__.py │ ├── citations.py │ ├── graph_kernel.py │ └── dataset.py ├── modules │ ├── __init__.py │ ├── gcn.py │ └── gat.py ├── utils │ ├── __init__.py │ ├── cudaprofile.py │ ├── lr_schedular.py │ └── logger.py ├── convs │ ├── __init__.py │ ├── gcn_conv.py │ └── gat_conv.py └── functional │ ├── __init__.py │ ├── dropout.py │ ├── format.py │ ├── gcn.py │ ├── aggregate.py │ └── gat.py ├── src ├── install.sh ├── cuda │ ├── format.cpp │ ├── gcn.cpp │ ├── gat.cpp │ ├── aggregate.cpp │ ├── format_kernel.cu │ ├── gcn_kernel.cu │ ├── gat_kernel.cu │ └── aggregate_kernel.cu └── setup.py ├── README.md └── training_main.py /fuseGNN/testbench/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fuseGNN/dataloader/reddit.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import Reddit 2 | -------------------------------------------------------------------------------- /src/install.sh: -------------------------------------------------------------------------------- 1 | pip uninstall fuseGNN 2 | python -W ignore setup.py build 3 | python -W ignore setup.py install -------------------------------------------------------------------------------- /fuseGNN/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from fuseGNN.modules.gcn import GCN, gcn_config 2 | from fuseGNN.modules.gat import GAT, gat_config 3 | -------------------------------------------------------------------------------- /fuseGNN/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from fuseGNN.utils.lr_schedular import LrSchedular 2 | from fuseGNN.utils.logger import Logger 3 | from fuseGNN.utils.cudaprofile import start, stop -------------------------------------------------------------------------------- /fuseGNN/convs/__init__.py: -------------------------------------------------------------------------------- 1 | from fuseGNN.convs.gcn_conv import geoGCNConv, refGCNConv, garGCNConv, gasGCNConv 2 | from fuseGNN.convs.gat_conv import geoGATConv, refGATConv, garGATConv, gasGATConv 3 | -------------------------------------------------------------------------------- /fuseGNN/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from fuseGNN.dataloader.citations import Citations 2 | from fuseGNN.dataloader.reddit import Reddit 3 | from fuseGNN.dataloader.graph_kernel import graph_kernel_dataset -------------------------------------------------------------------------------- /fuseGNN/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from fuseGNN.functional.format import coo2csr, csr2csc 2 | from fuseGNN.functional.gcn import gcn_gar_edge_weight, gcn_gas_edge_weight 3 | from fuseGNN.functional.gat import gat_gar_edge_weight, gat_gas_edge_weight 4 | from fuseGNN.functional.aggregate import fused_gar_agg, fused_gas_agg 5 | from fuseGNN.functional.dropout import Dropout 6 | 7 | dropout = Dropout.apply -------------------------------------------------------------------------------- /fuseGNN/utils/cudaprofile.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | 3 | 
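# Illustrative note (not part of the original file): this module wraps
# cudaProfilerStart()/cudaProfilerStop() from libcudart via ctypes, so that an
# externally attached profiler records only the bracketed region.  A minimal
# usage sketch, assuming the profiler is launched with capture initially off
# (e.g. `nvprof --profile-from-start off python training_main.py ...`):
#
#     from fuseGNN.utils import start, stop   # re-exported by fuseGNN/utils/__init__.py
#     start()            # begin capturing CUDA activity
#     out = model(data)  # region of interest, e.g. one forward pass
#     stop()             # end of the profiled region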
_cudart = ctypes.CDLL('libcudart.so') 4 | 5 | 6 | def start(): 7 | # As shown at http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__PROFILER.html, 8 | # the return value will unconditionally be 0. This check is just in case it changes in 9 | # the future. 10 | ret = _cudart.cudaProfilerStart() 11 | if ret != 0: 12 | raise Exception("cudaProfilerStart() returned %d" % ret) 13 | 14 | def stop(): 15 | ret = _cudart.cudaProfilerStop() 16 | if ret != 0: 17 | raise Exception("cudaProfilerStop() returned %d" % ret) -------------------------------------------------------------------------------- /fuseGNN/utils/lr_schedular.py: -------------------------------------------------------------------------------- 1 | class LrSchedular: 2 | def __init__(self, init_lr, mode, **kwargs): 3 | self.lr = init_lr 4 | self.mode = mode 5 | assert self.mode in ['constant', 'step_decay'], "Only 'constant' and 'step_decay' are supported" 6 | if mode == 'step_decay': 7 | self.interval = kwargs['interval'] 8 | self.rate = kwargs['rate'] 9 | 10 | def update(self, epoch, optimizer): 11 | if self.mode == 'step_decay': 12 | if epoch % self.interval == 0: 13 | self.lr = self.lr * self.rate 14 | for param_group in optimizer.param_groups: 15 | param_group['lr'] = self.lr 16 | -------------------------------------------------------------------------------- /fuseGNN/functional/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gcnlib_cuda 3 | 4 | # gcnlib_cuda.dropout(cooVals, 0.5, True) 5 | 6 | class Dropout(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, input_, rate, training): 9 | # what feed in is the probability of 0. 10 | out, mask = gcnlib_cuda.dropout(input_, rate, training, None) 11 | if not training: 12 | mask = torch.ones_like(input_) 13 | if training: 14 | ctx.save_for_backward(mask) 15 | return out, mask 16 | 17 | @staticmethod 18 | def backward(ctx, grad_out, grad_mask): 19 | mask = ctx.saved_tensors[0] 20 | return gcnlib_cuda.dropout_bp(grad_out, mask), None, None 21 | -------------------------------------------------------------------------------- /src/cuda/format.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | 5 | std::vector csr2csc_cuda( 6 | torch::Tensor inPtr, 7 | torch::Tensor inInd, 8 | torch::Tensor inVal, 9 | int num_row 10 | ); 11 | 12 | 13 | std::vector csr2csc( 14 | torch::Tensor inPtr, 15 | torch::Tensor inInd, 16 | torch::Tensor inVal, 17 | int num_row 18 | ){ 19 | return csr2csc_cuda(inPtr, inInd, inVal, num_row); 20 | } 21 | 22 | torch::Tensor coo2csr_cuda( 23 | torch::Tensor cooRowInd, 24 | int num_row 25 | ); 26 | 27 | torch::Tensor coo2csr( 28 | torch::Tensor cooRowInd, 29 | int num_row 30 | ){ 31 | return coo2csr_cuda(cooRowInd, num_row); 32 | } 33 | 34 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 35 | m.def("csr2csc", &csr2csc, "Converter between CSC and CSR"); 36 | m.def("coo2csr", &coo2csr, "Convert COO to CSR"); 37 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fuseGNN 2 | 3 | This is the source code for the paper "fuseGNN: Accelerating Graph Convolutional Neural Network Training on GPGPU". 4 | 5 | ## Organization of the project 6 | 7 | * The "src" folder contains all the CUDA kernels we developed. 
8 | * The "fuseGNN" folder contains all the python APIs in different level: functional -> convs -> modules. 9 | 10 | ## Constraint random test 11 | 12 | Under directory "fuseGNN/testbench", we provide two random testbenchs for a single layer GCN and GAT for both forward and backward passes. 13 | 14 | ## Training 15 | 16 | The training of GCN and GAT on different datasets can be launched with "training_main.py". Different implementations can be selected with the argument "--mode", in particular, "geo" for pytorch geometric, gas for our fused-GAS abstraction, and gar for our fused-GAR abstraction. For GAT, we demonstrate the single attention head scenario, while the multi-attention head can be implemented by slightly modifying the related cuda kernels. 17 | -------------------------------------------------------------------------------- /fuseGNN/utils/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tensorboardX import SummaryWriter 4 | 5 | 6 | class Logger: 7 | """ 8 | Record the results and store them to the json file 9 | """ 10 | def __init__(self, model_, data_, log_dir): 11 | if not os.path.exists(log_dir): 12 | os.mkdir(log_dir) 13 | self.log = { 14 | 'model': model_, 15 | 'data': data_, 16 | 'train_loss': [], 17 | 'train_acc': [], 18 | 'val_acc': [], 19 | 'test_acc': [], 20 | 'epoch': [], 21 | } 22 | self.exp_name = model_ + '_' + data_ 23 | json_dir = log_dir + self.exp_name 24 | self.writer = SummaryWriter(log_dir=json_dir) 25 | if not os.path.exists(json_dir): 26 | os.mkdir(json_dir) 27 | self.log_file = json_dir + 'log.json' 28 | 29 | def add_scalar(self, key, value, epoch): 30 | self.log[key].append(value) 31 | self.writer.add_scalar(tag=key, scalar_value=value, global_step=epoch) 32 | 33 | def write(self): 34 | with open(self.log_file, 'w') as file: 35 | json.dump(self.log, file) 36 | self.writer.close() 37 | -------------------------------------------------------------------------------- /src/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CppExtension, BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='fuseGNN', 6 | version='0.0.1', 7 | description='Custom library for graph convolutional networks for pytorch', 8 | author='Zhaodong Chen', 9 | author_email='chenzd15thu@ucsb.edu', 10 | ext_modules=[ 11 | CUDAExtension('fgnn_agg', 12 | ['cuda/aggregate.cpp', 'cuda/aggregate_kernel.cu'], 13 | extra_compile_args={'cxx':[], 'nvcc':['-arch=sm_70']}), 14 | CUDAExtension('fgnn_format', 15 | ['cuda/format.cpp', 'cuda/format_kernel.cu'], 16 | extra_compile_args={'cxx':[], 'nvcc':['-arch=sm_70', '-lcusparse']}), 17 | CUDAExtension('fgnn_gcn', 18 | ['cuda/gcn.cpp', 'cuda/gcn_kernel.cu'], 19 | extra_compile_args={'cxx':[], 'nvcc':['-arch=sm_70']}), 20 | CUDAExtension('fgnn_gat', 21 | ['cuda/gat.cpp', 'cuda/gat_kernel.cu'], 22 | extra_compile_args={'cxx':[], 'nvcc':['-arch=sm_70']}), 23 | ], 24 | cmdclass={'build_ext': BuildExtension}, 25 | install_requires=['torch'] 26 | ) 27 | -------------------------------------------------------------------------------- /fuseGNN/functional/format.py: -------------------------------------------------------------------------------- 1 | import fgnn_format 2 | import torch 3 | 4 | 5 | class Coo2Csr(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, tar_index, num_node): 8 | tar_ptr = fgnn_format.coo2csr(tar_index, num_node) 9 | return tar_ptr 
10 | 11 | @staticmethod 12 | def backward(ctx, grad_tar_ptr): 13 | return None, None 14 | 15 | 16 | def coo2csr(src_index, tar_index, num_node, edge_weight=None, sorted=False): 17 | if not sorted: 18 | tar_index, indices = torch.sort(tar_index, dim=0) 19 | src_index = torch.gather(src_index, 0, indices) 20 | if edge_weight is not None: 21 | edge_weight = torch.gather(edge_weight.squeeze(), 0, indices) 22 | tar_ptr = Coo2Csr.apply(tar_index, num_node) 23 | return src_index, tar_index, tar_ptr, edge_weight 24 | 25 | 26 | 27 | class Csr2Csc(torch.autograd.Function): 28 | @staticmethod 29 | def forward(ctx, inPtr, inInd, inVal, num_row): 30 | outPtr, outInd, outVal = fgnn_format.csr2csc(inPtr, inInd, inVal, num_row) 31 | return outPtr, outInd, outVal 32 | 33 | @staticmethod 34 | def backward(ctx, grad_outPtr, grad_outInd, grad_outVal): 35 | return None, None, None 36 | 37 | 38 | csr2csc = Csr2Csc.apply -------------------------------------------------------------------------------- /src/cuda/gcn.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | 5 | std::vector gcn_gar_egde_weight_cuda( 6 | torch::Tensor src_index, 7 | torch::Tensor tar_ptr, 8 | torch::Tensor tar_index, 9 | int num_nodes, 10 | torch::optional optional_edge_weight, 11 | bool tar_to_src 12 | ); 13 | 14 | std::vector gcn_gar_egde_weight( 15 | torch::Tensor src_index, 16 | torch::Tensor tar_ptr, 17 | torch::Tensor tar_index, 18 | int num_nodes, 19 | torch::optional optional_edge_weight, 20 | bool tar_to_src 21 | ){ 22 | return gcn_gar_egde_weight_cuda(src_index, tar_ptr, tar_index, num_nodes, optional_edge_weight, tar_to_src); 23 | } 24 | 25 | std::vector gcn_gas_edge_weight_cuda( 26 | torch::Tensor src_index, 27 | torch::Tensor tar_index, 28 | int num_nodes, 29 | torch::optional optional_edge_weight, 30 | bool tar_to_src 31 | ); 32 | 33 | 34 | std::vector gcn_gas_edge_weight( 35 | torch::Tensor src_index, 36 | torch::Tensor tar_index, 37 | int num_nodes, 38 | torch::optional optional_edge_weight, 39 | bool tar_to_src 40 | ){ 41 | return gcn_gas_edge_weight_cuda(src_index, tar_index, num_nodes, optional_edge_weight, tar_to_src); 42 | } 43 | 44 | 45 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 46 | m.def("gcn_gar_edge_weight", &gcn_gar_egde_weight, "gcn_gar_edge_weight"); 47 | m.def("gcn_gas_edge_weight", &gcn_gas_edge_weight, "gcn_gas_edge_weight"); 48 | } -------------------------------------------------------------------------------- /fuseGNN/dataloader/citations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import InMemoryDataset, download_url 3 | from torch_geometric.io import read_planetoid_data 4 | 5 | class Citations(InMemoryDataset): 6 | """ 7 | The citation network dtatasets "Cora", "CiteSeer", and "PubMed" 8 | Nodes represent documents and edges represent citation links 9 | Training, validation and test splits are given by binary masks. 
10 | 11 | The edges are undirected and the edge weights are binary 12 | """ 13 | url = 'https://github.com/kimiyoung/planetoid/raw/master/data' 14 | 15 | def __init__(self, root, name, transform=None, pre_transform=None): 16 | self.name = name 17 | super(Citations, self).__init__(root, transform, pre_transform) 18 | self.data, self.slices = torch.load(self.processed_paths[0]) 19 | 20 | @property 21 | def raw_file_names(self): 22 | names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index'] 23 | return ['ind.{}.{}'.format(self.name.lower(), name) for name in names] 24 | 25 | @property 26 | def processed_file_names(self): 27 | return 'data.pt' 28 | 29 | def download(self): 30 | for name in self.raw_file_names: 31 | download_url('{}/{}'.format(self.url, name), self.raw_dir) 32 | 33 | def process(self): 34 | data = read_planetoid_data(self.raw_dir, self.name) 35 | data = data if self.pre_transform is None else self.pre_transform(data) 36 | torch.save(self.collate([data]), self.processed_paths[0]) 37 | 38 | def __repr__(self): 39 | return '{}()'.format(self.name) 40 | -------------------------------------------------------------------------------- /fuseGNN/functional/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fgnn_gcn 3 | 4 | # fused get edge weight function for GAR model 5 | 6 | class Inv(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, degree): 9 | self_edge_weight = 1./degree 10 | return self_edge_weight 11 | @staticmethod 12 | def backward(ctx, gred_self_edge_weight): 13 | return None 14 | 15 | class GCNGAREdgeWeight(torch.autograd.Function): 16 | @staticmethod 17 | def forward(ctx, src_index, tar_ptr, tar_index, num_nodes, edge_weight, flow): 18 | edge_weight, degree = fgnn_gcn.gcn_gar_edge_weight(src_index, tar_ptr, tar_index, num_nodes, edge_weight, 19 | flow=='target_to_source') 20 | return edge_weight, degree 21 | 22 | @staticmethod 23 | def backward(ctx, grad_edge_weight, grad_degree): 24 | return None, None, None, None, None, None 25 | 26 | def gcn_gar_edge_weight(src_index, tar_ptr, tar_index, num_nodes, edge_weight, flow): 27 | edge_weight, degree = GCNGAREdgeWeight.apply(src_index, tar_ptr, tar_index, num_nodes, edge_weight, flow) 28 | self_edge_weight = Inv.apply(degree) 29 | return edge_weight, self_edge_weight 30 | 31 | # fused get edge weight function for GAS model 32 | 33 | class GCNGASEdgeWeight(torch.autograd.Function): 34 | @staticmethod 35 | def forward(ctx, src_index, tar_index, num_nodes, edge_weight, flow): 36 | weight_to_cache, degree_to_cache = fgnn_gcn.gcn_gas_edge_weight(src_index, tar_index, num_nodes, edge_weight, 37 | flow=='target_to_source') 38 | return weight_to_cache, degree_to_cache 39 | 40 | @staticmethod 41 | def backward(ctx, grad_edge_weight, grad_degree): 42 | return None, None, None, None, None, None 43 | 44 | def gcn_gas_edge_weight(src_index, tar_index, num_nodes, edge_weight, flow): 45 | edge_weight, degree = GCNGASEdgeWeight.apply(src_index, tar_index, num_nodes, edge_weight, flow) 46 | self_edge_weight = Inv.apply(degree) 47 | return edge_weight, self_edge_weight 48 | -------------------------------------------------------------------------------- /fuseGNN/dataloader/graph_kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.datasets import TUDataset 3 | from torch_geometric.utils import degree 4 | import torch_geometric.transforms as T 5 | from torch_geometric.data 
import DataLoader 6 | 7 | 8 | class NormalizedDegree(object): 9 | def __init__(self, mean, std): 10 | self.mean = mean 11 | self.std = std 12 | 13 | def __call__(self, data): 14 | deg = degree(data.edge_index[0], dtype=torch.float) 15 | deg = (deg - self.mean) / self.std 16 | data.x = deg.view(-1, 1) 17 | return data 18 | 19 | 20 | def graph_kernel_dataset(name, path, sparse=True): 21 | dataset = TUDataset(path, name) 22 | dataset.data.edge_attr = None 23 | 24 | if dataset.data.x is None: 25 | max_degree = 0 26 | degs = [] 27 | for data in dataset: 28 | degs += [degree(data.edge_index[0], dtype=torch.long)] 29 | max_degree = max(max_degree, degs[-1].max().item()) 30 | 31 | if max_degree < 1000: 32 | dataset.transform = T.OneHotDegree(max_degree) 33 | else: 34 | deg = torch.cat(degs, dim=0).to(torch.float) 35 | mean, std = deg.mean().item(), deg.std().item() 36 | dataset.transform = NormalizedDegree(mean, std) 37 | 38 | if not sparse: 39 | num_nodes = max_num_nodes = 0 40 | for data in dataset: 41 | num_nodes += data.num_nodes 42 | max_num_nodes = max(data.num_nodes, max_num_nodes) 43 | 44 | # Filter out a few really large graphs in order to apply DiffPool. 45 | if name == 'REDDIT-BINARY': 46 | num_nodes = min(int(num_nodes / len(dataset) * 1.5), max_num_nodes) 47 | else: 48 | num_nodes = min(int(num_nodes / len(dataset) * 5), max_num_nodes) 49 | 50 | indices = [] 51 | for i, data in enumerate(dataset): 52 | if data.num_nodes <= num_nodes: 53 | indices.append(i) 54 | dataset = dataset[torch.Tensor(indices)] 55 | 56 | if dataset.transform is None: 57 | dataset.transform = T.ToDense(num_nodes) 58 | else: 59 | dataset.transform = T.Compose( 60 | [dataset.transform, T.ToDense(num_nodes)]) 61 | 62 | return dataset -------------------------------------------------------------------------------- /src/cuda/gat.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | 5 | std::vector gat_gar_edge_weight_cuda( 6 | torch::Tensor e_pre, 7 | torch::Tensor src_ptr, 8 | torch::Tensor tar_index, 9 | float negative_slope 10 | ); 11 | 12 | 13 | std::vector gat_gar_edge_weight( 14 | torch::Tensor e_pre, 15 | torch::Tensor src_ptr, 16 | torch::Tensor tar_index, 17 | float negative_slope 18 | ){ 19 | return gat_gar_edge_weight_cuda(e_pre, src_ptr, tar_index, negative_slope); 20 | } 21 | 22 | 23 | std::vector gat_gas_edge_weight_cuda( 24 | torch::Tensor e_pre, 25 | torch::Tensor src_index, 26 | torch::Tensor tar_index, 27 | float negative_slope 28 | ); 29 | 30 | 31 | std::vector gat_gas_edge_weight( 32 | torch::Tensor e_pre, 33 | torch::Tensor src_index, 34 | torch::Tensor tar_index, 35 | float negative_slope 36 | ){ 37 | return gat_gas_edge_weight_cuda(e_pre, src_index, tar_index, negative_slope); 38 | } 39 | 40 | 41 | std::vector gat_gar_edge_weight_b_cuda( 42 | torch::Tensor grad_alpha_self, 43 | torch::Tensor grad_alpha, 44 | torch::Tensor src_index, 45 | torch::Tensor tar_index, 46 | torch::Tensor mask_lrelu, 47 | torch::Tensor mask_lrelu_self, 48 | torch::Tensor e, 49 | torch::Tensor e_self, 50 | torch::Tensor e_sum, 51 | torch::Tensor alpha_self, 52 | torch::Tensor alpha 53 | ); 54 | 55 | std::vector gat_gar_edge_weight_b( 56 | torch::Tensor grad_alpha_self, 57 | torch::Tensor grad_alpha, 58 | torch::Tensor src_index, 59 | torch::Tensor tar_index, 60 | torch::Tensor mask_lrelu, 61 | torch::Tensor mask_lrelu_self, 62 | torch::Tensor e, 63 | torch::Tensor e_self, 64 | torch::Tensor e_sum, 65 | torch::Tensor alpha_self, 66 | torch::Tensor alpha 67 
| ){ 68 | return gat_gar_edge_weight_b_cuda(grad_alpha_self, grad_alpha, src_index, tar_index, mask_lrelu, mask_lrelu_self, e, e_self, e_sum, alpha_self, alpha); 69 | } 70 | 71 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 72 | m.def("gat_gar_edge_weight", &gat_gar_edge_weight, "gat_gar_edge_weight"); 73 | m.def("gat_gas_edge_weight", &gat_gas_edge_weight, "gat_gas_edge_weight"); 74 | m.def("gat_gar_edge_weight_b", &gat_gar_edge_weight_b, "gat_gar_edge_weight_b"); 75 | } 76 | -------------------------------------------------------------------------------- /src/cuda/aggregate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | 5 | torch::Tensor fused_gar_f_cuda( 6 | torch::Tensor feature, 7 | torch::Tensor src_index, 8 | torch::Tensor tar_ptr, 9 | torch::Tensor edge_weight, 10 | torch::Tensor self_edge_weight 11 | ); 12 | 13 | 14 | torch::Tensor fused_gar_f( 15 | torch::Tensor feature, 16 | torch::Tensor src_index, 17 | torch::Tensor tar_ptr, 18 | torch::Tensor edge_weight, 19 | torch::Tensor self_edge_weight 20 | ){ 21 | return fused_gar_f_cuda(feature, src_index, tar_ptr, edge_weight, self_edge_weight); 22 | } 23 | 24 | 25 | std::vector fused_gar_b_cuda( 26 | torch::Tensor grad_out, 27 | torch::Tensor feature, 28 | torch::Tensor tar_index, 29 | torch::Tensor src_ptr, 30 | torch::Tensor edge_weight, 31 | torch::Tensor self_edge_weight, 32 | bool require_edge_weight 33 | ); 34 | 35 | 36 | std::vector fused_gar_b( 37 | torch::Tensor grad_out, 38 | torch::Tensor feature, 39 | torch::Tensor tar_index, 40 | torch::Tensor src_ptr, 41 | torch::Tensor edge_weight, 42 | torch::Tensor self_edge_weight, 43 | bool require_edge_weight 44 | ){ 45 | return fused_gar_b_cuda(grad_out, feature, tar_index, src_ptr, edge_weight, self_edge_weight, require_edge_weight); 46 | } 47 | 48 | 49 | torch::Tensor fused_gas_f_cuda( 50 | torch::Tensor feature, 51 | torch::Tensor src_index, 52 | torch::Tensor tar_index, 53 | torch::Tensor edge_weight, 54 | torch::Tensor self_edge_weight 55 | ); 56 | 57 | 58 | torch::Tensor fused_gas_f( 59 | torch::Tensor feature, 60 | torch::Tensor src_index, 61 | torch::Tensor tar_index, 62 | torch::Tensor edge_weight, 63 | torch::Tensor self_edge_weight 64 | ){ 65 | return fused_gas_f_cuda(feature, src_index, tar_index, edge_weight, self_edge_weight); 66 | } 67 | 68 | 69 | std::vector fused_gas_b_cuda( 70 | torch::Tensor grad_out, 71 | torch::Tensor feature, 72 | torch::Tensor src_index, 73 | torch::Tensor tar_index, 74 | torch::Tensor edge_weight, 75 | torch::Tensor self_edge_weight, 76 | bool require_edge_weight 77 | ); 78 | 79 | 80 | std::vector fused_gas_b( 81 | torch::Tensor grad_out, 82 | torch::Tensor feature, 83 | torch::Tensor src_index, 84 | torch::Tensor tar_index, 85 | torch::Tensor edge_weight, 86 | torch::Tensor self_edge_weight, 87 | bool require_edge_weight 88 | ){ 89 | return fused_gas_b_cuda(grad_out, feature, src_index, tar_index, edge_weight, self_edge_weight, require_edge_weight); 90 | } 91 | 92 | 93 | 94 | 95 | 96 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 97 | m.def("fused_gar_f", &fused_gar_f, "fused AGG GAR forward"); 98 | m.def("fused_gar_b", &fused_gar_b, "fused AGG GAR backward"); 99 | m.def("fused_gas_f", &fused_gas_f, "fused AGG GAS forward"); 100 | m.def("fused_gas_b", &fused_gas_b, "fused AGG GAS backward"); 101 | } 102 | -------------------------------------------------------------------------------- /src/cuda/format_kernel.cu: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | std::vector csr2csc_cuda( 10 | torch::Tensor inPtr, 11 | torch::Tensor inInd, 12 | torch::Tensor inVal, 13 | int num_row 14 | ){ 15 | // initialize the output tensor 16 | auto outPtr = torch::zeros_like(inPtr); 17 | auto outInd = torch::empty_like(inInd); 18 | auto outVal = torch::empty_like(inVal); 19 | int nnz = inInd.size(0); 20 | 21 | // create cusparse handler 22 | cusparseHandle_t handle = 0; 23 | cusparseCreate(&handle); 24 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 25 | cusparseSetStream(handle, stream); 26 | /* 27 | cusparseScsr2csc(handle, num_row, num_row, nnz, inVal.data(), inPtr.data(), 28 | inInd.data(), outVal.data(), outInd.data(), outPtr.data(), CUSPARSE_ACTION_SYMBOLIC, 29 | CUSPARSE_INDEX_BASE_ZERO); 30 | */ 31 | // for CUDA 10.2 32 | // Determine temporary device storage requirement 33 | void *d_temp_storage = NULL; 34 | size_t temp_storage_bytes = 0; 35 | 36 | AT_DISPATCH_FLOATING_TYPES(inVal.type(), "get temporary device storage requirement", ([&]{ 37 | cusparseCsr2cscEx2_bufferSize(handle, num_row, num_row, nnz, inVal.data(), inPtr.data(), 38 | inInd.data(), outVal.data(), outPtr.data(), outInd.data(), 39 | CUDA_R_32F, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, 40 | &temp_storage_bytes 41 | ); 42 | })); 43 | 44 | 45 | 46 | // Allocate temporary storage 47 | cudaMalloc(&d_temp_storage, temp_storage_bytes); 48 | 49 | // Do the conversion 50 | AT_DISPATCH_FLOATING_TYPES(inVal.type(), "type convert", ([&]{ 51 | cusparseCsr2cscEx2(handle, num_row, num_row, nnz, inVal.data(), inPtr.data(), 52 | inInd.data(), outVal.data(), outPtr.data(), outInd.data(), 53 | CUDA_R_32F, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, 54 | d_temp_storage 55 | ); 56 | })); 57 | 58 | cusparseDestroy(handle); 59 | cudaFree(d_temp_storage); 60 | return {outPtr, outInd, outVal}; 61 | } 62 | 63 | 64 | torch::Tensor coo2csr_cuda( 65 | torch::Tensor cooRowInd, 66 | int num_row 67 | ){ 68 | // initialize the output tensor 69 | auto options = torch::TensorOptions().dtype(torch::kInt32).device(cooRowInd.device()); 70 | auto csrRowPtr = torch::empty({num_row + 1, }, options); 71 | int nnz = cooRowInd.size(0); 72 | 73 | // create cusparse handler 74 | cusparseHandle_t handle = 0; 75 | cusparseCreate(&handle); 76 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 77 | cusparseSetStream(handle, stream); 78 | cusparseXcoo2csr(handle, cooRowInd.data(), nnz, num_row, csrRowPtr.data(), CUSPARSE_INDEX_BASE_ZERO); 79 | // cudaDeviceSynchronize(); 80 | 81 | cusparseDestroy(handle); 82 | return csrRowPtr; 83 | } -------------------------------------------------------------------------------- /fuseGNN/dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.transforms as T 3 | from sklearn.model_selection import StratifiedKFold 4 | from torch_geometric.data import DataLoader, DenseDataLoader 5 | from fuseGNN.dataloader import Citations, graph_kernel_dataset 6 | 7 | 8 | def k_fold(dataset, folds): 9 | """ 10 | Separate the dataset into k folds 11 | """ 12 | skf = StratifiedKFold(folds, shuffle=True, random_state=12345) 13 | 14 | test_indices, train_indices = [], [] 15 | for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y): 16 | # split: generate indices to split data 
17 | # _: train indices, idx: test indices 18 | test_indices.append(torch.from_numpy(idx)) 19 | 20 | val_indices = [test_indices[i - 1] for i in range(folds)] 21 | # the validation set is just a permutated list of test set 22 | 23 | for i in range(folds): 24 | train_mask = torch.ones(len(dataset), dtype=torch.bool) 25 | train_mask[test_indices[i]] = 0 26 | train_mask[val_indices[i]] = 0 27 | train_indices.append(train_mask.nonzero().view(-1)) 28 | 29 | return train_indices, test_indices, val_indices # just returns three lists, each one has length fold. 30 | 31 | 32 | class DataProvider: 33 | def __init__(self, data, data_path, task, **kwargs): 34 | path = data_path + data 35 | 36 | if data in ['CiteSeer', 'Cora', 'PubMed']: 37 | """ Citation Networks: CiteSeer, Cora, and PubMed 38 | Description: 39 | these datasets contain sparse bag-of-words feature vectors for each document and a list of citation 40 | links between documents. The citation links are treated as undirected edges 41 | """ 42 | assert task == "node_class", "%s is a dataset for node classification" % data 43 | self.dataset = Citations(path, data, T.NormalizeFeatures()) 44 | self.data = self.dataset[0] 45 | # these citation networks are small enough fit into the GPUs, but they still have enough nodes for 46 | # training, validation, and test. 47 | 48 | elif data in ['COLLAB', 'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY', 'REDDIT-MULTI5K', 'PROTEINS', 'MUTAG', 49 | 'PTC', 'NCI1']: 50 | """ Benchmark Data Sets for Graph Kernels 51 | Description: 52 | these datasets are the benchmarks for graph classification, the detailed information is available 53 | at 54 | """ 55 | assert task == "graph_class", "%s is a dataset for graph classification" % data 56 | self.dataset = graph_kernel_dataset(data, path) 57 | self.train_ids, self.test_ids, self.val_ids = k_fold(self.dataset, kwargs['fold']) 58 | self.batch_size = kwargs['batch_size'] 59 | # these datasets are much smaller. For instance, MUTAG only has 188 different graphs. So unlike the above 60 | # node classification, here we need cross validation. 
61 | else: 62 | raise NameError('unknown dataset') 63 | 64 | def get_cross_validation_loader(self, fold): 65 | (train_idx, test_idx, val_idx) = self.train_ids[fold], self.test_ids[fold], self.val_ids[fold] 66 | train_dataset = self.dataset[train_idx] 67 | test_dataset = self.dataset[test_idx] 68 | val_dataset = self.dataset[val_idx] 69 | if 'adj' in train_dataset[0]: 70 | train_loader = DenseDataLoader(train_dataset, self.batch_size, shuffle=True) 71 | val_loader = DenseDataLoader(val_dataset, self.batch_size, shuffle=False) 72 | test_loader = DenseDataLoader(test_dataset, self.batch_size, shuffle=False) 73 | else: 74 | train_loader = DataLoader(train_dataset, self.batch_size, shuffle=True) 75 | val_loader = DataLoader(val_dataset, self.batch_size, shuffle=False) 76 | test_loader = DataLoader(test_dataset, self.batch_size, shuffle=False) 77 | return train_loader, val_loader, test_loader 78 | -------------------------------------------------------------------------------- /fuseGNN/functional/aggregate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fgnn_agg 3 | from fuseGNN.functional.format import csr2csc 4 | 5 | 6 | # fused Aggregation phase with GAR model 7 | 8 | class fusedGARAggV1(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, feature, src_index, tar_ptr, edge_weight_f, 11 | self_edge_weight, tar_index, src_ptr, edge_weight_b): 12 | out = fgnn_agg.fused_gar_f(feature, src_index, tar_ptr, edge_weight_f, self_edge_weight) 13 | ctx.save_for_backward(tar_index, src_ptr, edge_weight_b, self_edge_weight) 14 | return out 15 | 16 | @staticmethod 17 | def backward(ctx, grad_out): 18 | tar_index, src_ptr, edge_weight_b, self_edge_weight = ctx.saved_tensors 19 | grad_features, _, _2 = fgnn_agg.fused_gar_b(grad_out, tar_index, tar_index, src_ptr, 20 | edge_weight_b, self_edge_weight, False) 21 | return grad_features, None, None, None, None, None, None, None, None 22 | 23 | 24 | class fusedGARAggV2(torch.autograd.Function): 25 | @staticmethod 26 | def forward(ctx, feature, src_index, tar_ptr, edge_weight_f, 27 | self_edge_weight, tar_index, src_ptr, edge_weight_b): 28 | out = fgnn_agg.fused_gar_f(feature, src_index, tar_ptr, edge_weight_f, self_edge_weight) 29 | ctx.save_for_backward(tar_index, src_ptr, edge_weight_b, self_edge_weight, feature) 30 | return out 31 | 32 | @staticmethod 33 | def backward(ctx, grad_out): 34 | tar_index, src_ptr, edge_weight_b, self_edge_weight, feature = ctx.saved_tensors 35 | grad_features, grad_edge_weight, grad_weight_self = fgnn_agg.fused_gar_b(grad_out, feature, tar_index, src_ptr, 36 | edge_weight_b, self_edge_weight, 37 | True) 38 | return grad_features, None, None, None, grad_weight_self, None, None, grad_edge_weight, None 39 | 40 | 41 | def fused_gar_agg(feature, src_index, tar_ptr, edge_weight_f, 42 | self_edge_weight, tar_index, src_ptr, edge_weight_b, 43 | require_edge_weight=False): 44 | if require_edge_weight: 45 | return fusedGARAggV2.apply(feature, src_index, tar_ptr, edge_weight_f, 46 | self_edge_weight, tar_index, src_ptr, edge_weight_b) 47 | else: 48 | return fusedGARAggV1.apply(feature, src_index, tar_ptr, edge_weight_f, 49 | self_edge_weight, tar_index, src_ptr, edge_weight_b) 50 | 51 | class fusedGASAggV1(torch.autograd.Function): 52 | @staticmethod 53 | def forward(ctx, feature, src_index, tar_index, edge_weight, self_edge_weight): 54 | ctx.save_for_backward(src_index, tar_index, edge_weight, self_edge_weight) 55 | out = fgnn_agg.fused_gas_f(feature, src_index, 
tar_index, edge_weight, self_edge_weight) 56 | return out 57 | 58 | @staticmethod 59 | def backward(ctx, grad_out): 60 | src_index, tar_index, edge_weight, self_edge_weight = ctx.saved_tensors 61 | grad_features, _, _2 = fgnn_agg.fused_gas_b(grad_out, src_index, src_index, tar_index, 62 | edge_weight, self_edge_weight, False) 63 | return grad_features, None, None, None, None 64 | 65 | 66 | class fusedGASAggV2(torch.autograd.Function): 67 | @staticmethod 68 | def forward(ctx, feature, src_index, tar_index, edge_weight, self_edge_weight): 69 | ctx.save_for_backward(src_index, tar_index, edge_weight, self_edge_weight, feature) 70 | out = fgnn_agg.fused_gas_f(feature, src_index, tar_index, edge_weight, self_edge_weight) 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_out): 75 | src_index, tar_index, edge_weight, self_edge_weight, feature = ctx.saved_tensors 76 | grad_features, grad_edge_weight, grad_weight_self = fgnn_agg.fused_gas_b(grad_out, feature, src_index, 77 | tar_index, edge_weight, 78 | self_edge_weight, True) 79 | return grad_features, None, None, grad_edge_weight, grad_weight_self 80 | 81 | 82 | 83 | def fused_gas_agg(feature, src_index, tar_index, edge_weight, self_edge_weight, require_edge_weight=False): 84 | if require_edge_weight: 85 | return fusedGASAggV2.apply(feature, src_index, tar_index, edge_weight, self_edge_weight) 86 | else: 87 | return fusedGASAggV1.apply(feature, src_index, tar_index, edge_weight, self_edge_weight) 88 | -------------------------------------------------------------------------------- /fuseGNN/functional/gat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_scatter 3 | import torch.nn.functional as F 4 | import fgnn_gat 5 | from fuseGNN.functional.format import Coo2Csr 6 | import torch_scatter 7 | 8 | 9 | class GATGAREdgeWeight(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, e_pre, src_ptr, tar_index, src_index, negative_slope): 12 | alpha, alpha_self, mask_lrelu, mask_lrelu_self, e_sum, e, e_self = fgnn_gat.gat_gar_edge_weight(e_pre, src_ptr, tar_index, negative_slope) 13 | ctx.save_for_backward(src_index, tar_index, mask_lrelu, mask_lrelu_self, e, e_self, e_sum, alpha_self, alpha) 14 | return alpha_self, alpha 15 | 16 | @staticmethod 17 | def backward(ctx, grad_alpha_self, grad_alpha): 18 | src_index, tar_index, mask_lrelu, mask_lrelu_self, e, e_self, e_sum, alpha_self, alpha = ctx.saved_tensors 19 | 20 | grad_e_sum, grad_e_pre = fgnn_gat.gat_gar_edge_weight_b(grad_alpha_self, grad_alpha, src_index, tar_index, mask_lrelu, mask_lrelu_self, 21 | e, e_self, e_sum, alpha_self, alpha) 22 | return grad_e_pre, None, None, None, None, None 23 | 24 | gat_gar_edge_weight = GATGAREdgeWeight.apply 25 | 26 | 27 | class GATGASEdgeWeight(torch.autograd.Function): 28 | @staticmethod 29 | def forward(ctx, e_pre, src_index, tar_index, negative_slope): 30 | alpha, alpha_self, mask_lrelu, mask_lrelu_self, e_sum, e, e_self = fgnn_gat.gat_gas_edge_weight(e_pre, src_index, tar_index, negative_slope) 31 | ctx.save_for_backward(src_index, tar_index, mask_lrelu, mask_lrelu_self, e, e_self, e_sum, alpha_self, alpha) 32 | return alpha_self, alpha 33 | 34 | @staticmethod 35 | def backward(ctx, grad_alpha_self, grad_alpha): 36 | src_index, tar_index, mask_lrelu, mask_lrelu_self, e, e_self, e_sum, alpha_self, alpha = ctx.saved_tensors 37 | 38 | grad_e_sum, grad_e_pre = fgnn_gat.gat_gar_edge_weight_b(grad_alpha_self, grad_alpha, src_index, tar_index, mask_lrelu, 
mask_lrelu_self, 39 | e, e_self, e_sum, alpha_self, alpha) 40 | return grad_e_pre, None, None, None, None 41 | 42 | gat_gas_edge_weight = GATGASEdgeWeight.apply 43 | 44 | 45 | 46 | class refGATEdgeWeight(torch.autograd.Function): 47 | @staticmethod 48 | def forward(ctx, e_pre, src_ptr, tar_index, src_index, negative_slope): 49 | e_pre_src, e_pre_tar = torch.chunk(e_pre, chunks=2, dim=1) 50 | e_pre_src_expand = torch.index_select(e_pre_src, dim=0, index=src_index) 51 | e_pre_tar_expand = torch.index_select(e_pre_tar, dim=0, index=tar_index) 52 | e = e_pre_src_expand + e_pre_tar_expand 53 | e_self = e_pre_src + e_pre_tar 54 | mask_lrelu = F.leaky_relu(e, negative_slope=negative_slope) / e 55 | mask_lrelu_self = F.leaky_relu(e_self, negative_slope=negative_slope) / e_self 56 | e *= mask_lrelu 57 | e_self *= mask_lrelu_self 58 | e = torch.exp(e) 59 | e_self = torch.exp(e_self) 60 | e_sum = torch_scatter.scatter_add(src=e, index=tar_index, dim=0, dim_size=e_self.size(0)) 61 | e_sum += e_self 62 | alpha_self = e_self / e_sum 63 | e_sum_ext = torch.index_select(e_sum, dim=0, index=tar_index) 64 | alpha = e / e_sum_ext 65 | 66 | ctx.save_for_backward(src_index, tar_index, mask_lrelu, mask_lrelu_self, e, e_self, e_sum, alpha_self, alpha) 67 | return alpha_self, alpha 68 | 69 | @staticmethod 70 | def backward(ctx, grad_alpha_self, grad_alpha): 71 | src_index, tar_index, mask_lrelu, mask_lrelu_self, e, e_self, e_sum, alpha_self, alpha = ctx.saved_tensors 72 | 73 | src_index = src_index.to(torch.int64) 74 | tar_index = tar_index.to(torch.int64) 75 | 76 | e_sum_ext = torch.index_select(e_sum, dim=0, index=tar_index) 77 | grad_e_sum = -grad_alpha * alpha / e_sum_ext 78 | grad_e = grad_alpha / e_sum_ext 79 | 80 | # e_sum_ext = torch.index_select(e_sum, dim=0, index=tar_index) 81 | grad_e_sum = torch_scatter.scatter_add(src=grad_e_sum, index=tar_index, dim=0, dim_size=alpha_self.size(0)) 82 | 83 | grad_e_self = grad_alpha_self / e_sum 84 | grad_e_sum -= grad_alpha_self * alpha_self / e_sum 85 | 86 | grad_e_self += grad_e_sum 87 | grad_e += torch.index_select(grad_e_sum, dim=0, index=tar_index) 88 | grad_e_self *= e_self 89 | grad_e *= e 90 | 91 | grad_e_self *= mask_lrelu_self 92 | grad_e *= mask_lrelu 93 | 94 | grad_e_pre_src = grad_e_self + torch_scatter.scatter_add(src=grad_e, index=src_index, dim=0, dim_size=grad_e_self.size(0)) 95 | grad_e_pre_tar = grad_e_self + torch_scatter.scatter_add(src=grad_e, index=tar_index, dim=0, dim_size=grad_e_self.size(0)) 96 | 97 | grad_e_pre = torch.cat((grad_e_pre_src, grad_e_pre_tar), dim=1) 98 | return grad_e_pre, None, None, None, None, None -------------------------------------------------------------------------------- /fuseGNN/modules/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from fuseGNN.convs import geoGCNConv, refGCNConv, garGCNConv, gasGCNConv 4 | from fuseGNN.utils import LrSchedular 5 | 6 | 7 | modules = { 8 | 'geo': geoGCNConv, 9 | 'ref': refGCNConv, 10 | 'gar': garGCNConv, 11 | 'gas': gasGCNConv 12 | } 13 | 14 | 15 | class GCN(torch.nn.Module): 16 | """ 17 | The graph convolutional operator from the `"Semi-supervised Classification with Graph Convolutional Networks" 18 | 19 | A two-layer GCN model. 
20 | The model is trained for 200 epochs with learning rate 0.01 and early stopped with a window size of 10 21 | (the validation loss doesn't decrease for 10 consecutive epochs) 22 | """ 23 | def __init__(self, num_features, hidden, num_classes, cached=True, drop_rate=0.5, mode='geo', flow='source_to_target'): 24 | """ 25 | :param num_features: the length of input features 26 | :param hidden: the length of hidden layer 27 | :param num_classes: the number of classes 28 | :param cached: If True, the layer will cache the computation on first execution, and will use the 29 | cached version for further executions. So it should be only true in transductive learning scenarios 30 | """ 31 | super(GCN, self).__init__() 32 | self.GCNConv = modules[mode] 33 | self.mode = mode 34 | self.conv1 = self.GCNConv(in_channels=num_features, out_channels=hidden, cached=cached, flow=flow) 35 | self.conv2 = self.GCNConv(in_channels=hidden, out_channels=num_classes, cached=cached, flow=flow) 36 | 37 | self.reg_params = self.conv1.parameters() 38 | self.non_reg_params = self.conv2.parameters() 39 | self.drop_rate = drop_rate 40 | 41 | def forward(self, data): 42 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 43 | x = F.relu(self.conv1(x, edge_index, edge_weight)) 44 | x = F.dropout(input=x, p=self.drop_rate, training=self.training) 45 | if self.mode == 'gar': 46 | x = self.conv2(x=x, edge_weight=self.conv1.cached_edge_weight_f, self_edge_weight=self.conv1.cached_self_edge_weight, 47 | tar_ptr=self.conv1.cached_tar_ptr, src_index=self.conv1.cached_src_index, 48 | src_ptr=self.conv1.cached_src_ptr, tar_index=self.conv1.cached_tar_index, 49 | edge_weight_b=self.conv1.cached_edge_weight_b) 50 | elif self.mode == 'gas': 51 | x = self.conv2(x=x, edge_weight=self.conv1.cached_edge_weight, src_index=self.conv1.cached_src_index, 52 | tar_index=self.conv1.cached_tar_index, self_edge_weight=self.conv1.cached_self_edge_weight) 53 | else: 54 | x = self.conv2(x, edge_index, edge_weight) 55 | return F.log_softmax(x, dim=1) 56 | 57 | 58 | """ 59 | Training configurations from 60 | "Semi-supervised Classification with Graph Convolutional Networks" 61 | The lines below as cited from the paper: 62 | > we train a two-layer GCN as ... 63 | > For the citation network datasets, we optimize hyperparameters on Cora only 64 | and use the same set of parameters for Citeseer and Pubmed. 65 | > We train all models for a maximum of 200 epochs (training iterations) 66 | using Adam with a learning rate of 0.01 67 | > We used the following sets of hyperparameters for Citeseer, Cora and Pubmed: 68 | 0.5 (dropout rate), 5 · 10−4 (L2 regularization) and 16 (number of hid- den units); 69 | and for NELL: 0.1 (dropout rate), 1 · 10−5 (L2 regularization) and 64 (number of hidden units). 
70 | """ 71 | gcn_config = { 72 | 'CiteSeer': { 73 | 'drop_rate': 0.5, 74 | 'weight_decay': 4e-4, 75 | 'hidden': 16, 76 | 'lr': 0.01, 77 | 'lr_schedular': LrSchedular(init_lr=0.01, mode='constant'), 78 | 'fold': 1, 79 | }, 80 | 'Cora': { 81 | 'drop_rate': 0.5, 82 | 'weight_decay': 4e-4, 83 | 'hidden': 16, 84 | 'lr': 0.01, 85 | 'lr_schedular': LrSchedular(init_lr=0.01, mode='constant'), 86 | 'fold': 1, 87 | }, 88 | 'PubMed': { 89 | 'drop_rate': 0.5, 90 | 'weight_decay': 4e-4, 91 | 'hidden': 16, 92 | 'lr': 0.01, 93 | 'lr_schedular': LrSchedular(init_lr=0.01, mode='constant'), 94 | 'fold': 1, 95 | }, 96 | 'Nell': { 97 | 'drop_rate': 0.1, 98 | 'weight_decay': 1e-5, 99 | 'hidden': 64, 100 | 'lr': 0.01, 101 | 'lr_schedular': LrSchedular(init_lr=0.01, mode='constant'), 102 | 'fold': 1, 103 | }, 104 | 'Reddit': { 105 | 'drop_rate': 0.5, 106 | 'weight_decay': 1e-5, 107 | 'hidden': 128, 108 | 'lr': 0.01, 109 | 'lr_schedular': LrSchedular(init_lr=0.01, mode='constant'), 110 | 'fold': 1, 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /training_main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training GCN/GAT on CiteSeer, Cora, PubMed, and Reddit datasets 3 | """ 4 | import argparse 5 | import os 6 | import torch 7 | from fuseGNN.modules import GCN, gcn_config, GAT, gat_config 8 | from fuseGNN.dataloader import Citations, Reddit 9 | from fuseGNN.utils import Logger 10 | import torch_geometric.transforms as T 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | import sys 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | # dataset config 18 | parser.add_argument('--data', choices=['CiteSeer', 'Cora', 'PubMed', 'Reddit'], help='dataset name') 19 | parser.add_argument('--model', choices=['GCN', 'GAT'], help='GCN model') 20 | parser.add_argument('--data_path', type=str, default='/raid/datasets/GNN/', help='the path to datasets') 21 | # training config 22 | parser.add_argument('--max_iter', type=int, default=200, help='maximum training iterations') 23 | parser.add_argument('--gpus', type=str, default='0', help='gpu to use') 24 | # logging 25 | parser.add_argument('--log_dir', type=str, default='./log/', help='the path to the logs') 26 | # model configures 27 | parser.add_argument('--mode', choices=['geo', 'ref', 'gas', 'gar'], default='ref', help='run which mode') 28 | parser.add_argument('--flow', choices=['target_to_source', 'source_to_target'], default='target_to_source') 29 | 30 | args = parser.parse_args() 31 | 32 | 33 | # configure CUDA 34 | os.system('export CUDA_VISIBLE_DEVICES=' + args.gpus) 35 | assert torch.cuda.is_available(), 'CUDA is not available' 36 | device = torch.device('cuda') 37 | 38 | 39 | # configure dataset 40 | path = args.data_path + args.data 41 | try: 42 | if args.data in ['Cora', 'CiteSeer', 'PubMed']: 43 | dataset = Citations(path, args.data, T.NormalizeFeatures()) 44 | elif args.data in ['Reddit']: 45 | dataset = Reddit(path) 46 | data = dataset[0] 47 | """ 48 | The data contains 49 | edge_index=[2, N(e)], test_mask=[N(v)], train_mask=[N(v)], val_mask=[N(v)], x=[N(v), dim], y=[N(v)], deg=[N(v)] 50 | The deg is very imbalanced, e.g. 
[1, 168] for Cora, however, I can still try to ignore it at this very begining 51 | """ 52 | except: 53 | print('The dataset does not exist or is not supported.') 54 | sys.exit() 55 | 56 | data.train_mask = data.train_mask.to(torch.bool) 57 | data.val_mask = data.val_mask.to(torch.bool) 58 | data.test_mask = data.test_mask.to(torch.bool) 59 | 60 | if args.model == 'GCN': 61 | config = gcn_config[args.data] 62 | model = GCN(num_features=dataset.num_features, hidden=config['hidden'], 63 | num_classes=dataset.num_classes, cached=True, drop_rate=config['drop_rate'], mode=args.mode, flow=args.flow) 64 | elif args.model == 'GAT': 65 | config = gat_config[args.data] 66 | model = GAT(num_features=dataset.num_features, hidden=config['hidden'], 67 | num_classes=dataset.num_classes, heads=config['head'], drop_rate=config['drop_rate'], mode=args.mode) 68 | 69 | logger = Logger(model_=args.model + '_' + args.mode, data_=args.data, log_dir=args.log_dir) 70 | 71 | optimizer = torch.optim.Adam([ 72 | dict(params=model.reg_params, weight_decay=config['weight_decay']), 73 | dict(params=model.non_reg_params, weight_decay=0.) 74 | ], lr=config['lr']) 75 | 76 | # training 77 | class NodeClassifier: 78 | def __init__(self, data, model): 79 | self.data = data.to(device) 80 | self.model = model.to(device) 81 | 82 | def train(self): 83 | self.model.train() 84 | optimizer.zero_grad() 85 | loss = F.nll_loss(self.model(self.data)[self.data.train_mask], self.data.y[self.data.train_mask]) 86 | loss.backward() 87 | optimizer.step() 88 | return loss.item() 89 | 90 | def test(self): 91 | self.model.eval() 92 | logits, accs = self.model(self.data), [] 93 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'): 94 | pred = logits[mask].max(1)[1] 95 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item() 96 | accs.append(acc) 97 | return accs 98 | 99 | def next_fold(self, fold_): 100 | pass 101 | 102 | classifier = NodeClassifier(data, model) 103 | 104 | pbar = tqdm(range(args.max_iter)) 105 | best_test_acc = 0. 106 | for epoch in pbar: 107 | config['lr_schedular'].update(epoch, optimizer) 108 | train_loss = classifier.train() 109 | train_acc, val_acc, test_acc = classifier.test() 110 | logger.add_scalar('train_loss', train_loss, epoch) 111 | logger.add_scalar('train_acc', train_acc, epoch) 112 | logger.add_scalar('val_acc', val_acc, epoch) 113 | logger.add_scalar('test_acc', test_acc, epoch) 114 | logger.log['epoch'].append(epoch) 115 | if test_acc > best_test_acc: 116 | best_test_acc = test_acc 117 | pbar.set_description('Train ACC: %.3f | Test ACC: %.3f | Loss: %.3f' % (train_acc, best_test_acc, train_loss)) 118 | 119 | logger.write() 120 | """ 121 | summary = torch.cuda.memory_summary(device=device) 122 | print(summary) 123 | mem_stat = torch.cuda.memory_stats(device=device) 124 | print("Peak Memory: %.3f MB" % float(float(mem_stat['allocated_bytes.all.peak']) / 1024. / 1024.)) 125 | """ -------------------------------------------------------------------------------- /fuseGNN/modules/gat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from fuseGNN.convs import geoGATConv, refGATConv, garGATConv, gasGATConv 4 | from fuseGNN.utils import LrSchedular 5 | 6 | 7 | class GAT(torch.nn.Module): 8 | """ 9 | The graph convolutional operator from the `"Graph Attention Networks" 10 | 11 | A two-layer GAT model. 
12 | The model is trained for 200 epochs with learning rate 0.01 and early stopped with a window size of 10 13 | (the validation loss doesn't decrease for 10 consecutive epochs) 14 | """ 15 | def __init__(self, num_features, hidden, heads, num_classes, drop_rate=0.5, mode='geo', flow='source_to_target'): 16 | """ 17 | :param num_features: the length of input features 18 | :param hidden: the length of hidden layer 19 | :param num_classes: the number of classes 20 | :param cached: If True, the layer will cache the computation on first execution, and will use the 21 | cached version for further executions. So it should be only true in transductive learning scenarios 22 | """ 23 | super(GAT, self).__init__() 24 | if mode == 'ref': 25 | self.conv1 = refGATConv(in_channels=num_features, out_channels=int(hidden/heads), heads=heads, dropout=drop_rate) 26 | self.conv2 = refGATConv(in_channels=hidden, out_channels=num_classes, concat=False, dropout=drop_rate) 27 | elif mode == 'gar': 28 | self.conv1 = garGATConv(in_channels=num_features, out_channels=int(hidden/heads), heads=heads, dropout=drop_rate, cached=False, return_mask=False) 29 | self.conv2 = garGATConv(in_channels=hidden, out_channels=num_classes, concat=False, dropout=drop_rate, cached=False, return_mask=False) 30 | elif mode == 'gas': 31 | self.conv1 = gasGATConv(in_channels=num_features, out_channels=int(hidden/heads), heads=heads, dropout=drop_rate, cached=False, return_mask=False) 32 | self.conv2 = gasGATConv(in_channels=hidden, out_channels=num_classes, concat=False, dropout=drop_rate, cached=False, return_mask=False) 33 | else: 34 | self.conv1 = geoGATConv(in_channels=num_features, out_channels=int(hidden/heads), heads=heads, dropout=drop_rate) 35 | self.conv2 = geoGATConv(in_channels=hidden, out_channels=num_classes, concat=False, dropout=drop_rate) 36 | self.reg_params = self.conv1.parameters() 37 | self.non_reg_params = self.conv2.parameters() 38 | self.mode = mode 39 | self.drop_rate = drop_rate 40 | 41 | def forward(self, data): 42 | x, edge_index = data.x, data.edge_index 43 | x = F.dropout(data.x, p=self.drop_rate, training=self.training) 44 | x = F.elu(self.conv1(x, edge_index)) 45 | x = F.dropout(input=x, p=self.drop_rate, training=self.training) 46 | if self.mode == 'gar': 47 | x = self.conv2(x=x, tar_index_b=self.conv1.cached_tar_index_b, src_index_b=self.conv1.cached_src_index_b, src_ptr=self.conv1.cached_src_ptr) 48 | elif self.mode == 'gas': 49 | x = self.conv2(x=x, src_index=self.conv1.cached_src_index, tar_index=self.conv1.cached_tar_index) 50 | else: 51 | x = self.conv2(x, edge_index) 52 | return F.log_softmax(x, dim=1) 53 | 54 | 55 | """ 56 | Training configurations from 57 | "Graph Attention Networks" 58 | The lines below as cited from the paper: 59 | > we apply a two-layer GAT model ... 60 | > The first layer consists of K = 8 attention heads computing F'= 8 features each (for a total of 64 features), 61 | followed by an exponential linear unit (ELU) 62 | > The second layer is used for classification: a single attention head that computes C features 63 | (where C is the number of classes), followed by a softmax activation. 64 | > During training, we apply L2 regulariza- tion with λ = 0.0005. 65 | > Furthermore, dropout with p = 0.6 is applied to both layers’ inputs, 66 | > normalized attention coefficients 67 | 68 | > For PubMed: we have applied K = 8 output attention heads (instead of one), and strengthened the L2 regularization to λ = 0.001. 
69 | Otherwise, the architecture matches the one used for Cora and Citeseer. 70 | 71 | > Both models are initialized using Glorot initialization and trained to minimize cross-entropy on the training nodes 72 | using the Adam SGD optimizer with an initial learning rate of 0.01 for Pubmed, and 0.005 for all other datasets. 73 | > 74 | > We train all models for a maximum of 200 epochs (training iterations) 75 | using Adam with a learning rate of 0.01 76 | > with a patience of 100 epochs 77 | """ 78 | gat_config = { 79 | 'CiteSeer': { 80 | 'drop_rate': 0.6, 81 | 'weight_decay': 5e-4, 82 | 'hidden': 64, 83 | 'lr': 0.005, 84 | 'head': 1, 85 | 'lr_schedular': LrSchedular(init_lr=0.005, mode='constant'), 86 | 'fold': 1, 87 | }, 88 | 'Cora': { 89 | 'drop_rate': 0.6, 90 | 'weight_decay': 5e-4, 91 | 'hidden': 64, 92 | 'lr': 0.005, 93 | 'head': 1, 94 | 'lr_schedular': LrSchedular(init_lr=0.005, mode='constant'), 95 | 'fold': 1, 96 | }, 97 | 'PubMed': { 98 | 'drop_rate': 0.6, 99 | 'weight_decay': 1e-3, 100 | 'hidden': 64, 101 | 'lr': 0.01, 102 | 'head': 1, 103 | 'lr_schedular': LrSchedular(init_lr=0.01, mode='constant'), 104 | 'fold': 1, 105 | }, 106 | 'Reddit': { 107 | 'drop_rate': 0.5, 108 | 'weight_decay': 1e-3, 109 | 'hidden': 128, 110 | 'lr': 0.001, 111 | 'head': 1, 112 | 'lr_schedular': LrSchedular(init_lr=0.01, mode='constant'), 113 | 'fold': 1, 114 | }, 115 | } 116 | -------------------------------------------------------------------------------- /fuseGNN/testbench/gcn_conv_tb.py: -------------------------------------------------------------------------------- 1 | """ 2 | A testbench that verifies training kernels (both forward pass and backward pass) 3 | """ 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch_scatter 8 | import sys 9 | import torch_geometric.transforms as T 10 | from torch_geometric.utils import add_remaining_self_loops 11 | from fuseGNN.dataloader import Citations 12 | from fuseGNN.convs import geoGCNConv, refGCNConv, garGCNConv, gasGCNConv 13 | from torch_geometric.utils import degree 14 | from tqdm import tqdm 15 | import numpy as np 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data_path', type=str, default='/raid/datasets/GNN/') 19 | args = parser.parse_args() 20 | 21 | datasets = ['Cora', 'CiteSeer', 'PubMed'] 22 | edge_permutes = [True, False] 23 | edge_weights = [True, False] 24 | flows = ['target_to_source', 'source_to_target'] 25 | 26 | # configure CUDA 27 | assert torch.cuda.is_available(), "CUDA is not available" 28 | device = torch.device('cuda') 29 | 30 | 31 | class CastratedGCN(torch.nn.Module): 32 | """ 33 | GCN model for profiling 34 | """ 35 | def __init__(self, num_features, data=None, hidden=256, num_classes=10, cached=True, fused=False, flow='target_to_source'): 36 | super(CastratedGCN, self).__init__() 37 | self.data = data 38 | if fused: 39 | self.conv1 = gasGCNConv(in_channels=num_features, out_channels=hidden, cached=cached, 40 | flow=flow, bias=False) 41 | else: 42 | self.conv1 = geoGCNConv(in_channels=num_features, out_channels=hidden, cached=cached, 43 | flow=flow, bias=False) 44 | 45 | 46 | def forward(self, x, data=None, edge_weight=None): 47 | if data is None: 48 | data = self.data 49 | x = self.conv1(x, data.edge_index, edge_weight) 50 | return x 51 | 52 | 53 | def single_test(data_name, edge_permute, ew, flow): 54 | # Configure dataset 55 | path = args.data_path + data_name 56 | try: 57 | dataset = Citations(path, data_name, T.NormalizeFeatures()) 58 | data = dataset[0] 59 | # 
todo: preprocess the degree vector 60 | """ 61 | The data contains 62 | edge_index=[2, N(e)], test_mask=[N(v)], train_mask=[N(v)], val_mask=[N(v)], x=[N(v), dim], y=[N(v)], deg=[N(v)] 63 | The deg is very imbalanced, e.g. [1, 168] for Cora, however, I can still try to ignore it at this very begining 64 | """ 65 | except: 66 | print('The dataset does not exist or is not supported.') 67 | sys.exit() 68 | data.to(device) 69 | if ew: 70 | # edge_weight = torch.randperm(data.edge_index.size(1)).to(torch.float32).to('cuda') + 1 71 | edge_weight = torch.abs_(torch.randn(size=(data.edge_index.size(1),), dtype=torch.float32, device=data.edge_index.device)) + 1 72 | # print(edge_weight.max()) 73 | else: 74 | edge_weight = None 75 | if edge_permute: 76 | r = torch.randperm(data.edge_index.size(1)) 77 | data.edge_index = data.edge_index[:, r] 78 | if ew: 79 | edge_weight = edge_weight[r] 80 | hidden = np.random.randint(low=5, high=1024) 81 | 82 | model = CastratedGCN(num_features=dataset.num_features, data=data, hidden=hidden, num_classes=dataset.num_classes, cached=False, fused=False, flow=flow) 83 | model.to(device) 84 | 85 | fmodel = CastratedGCN(num_features=dataset.num_features, data=data, hidden=hidden, num_classes=dataset.num_classes, cached=False, fused=True, flow=flow) 86 | fmodel.conv1.dense.weight = torch.nn.Parameter(model.conv1.weight.t()) 87 | # fmodel.conv2.weight = model.conv2.weight 88 | # fmodel.conv1.dense.bias = torch.nn.Parameter(model.conv1.bias) 89 | # fmodel.conv1.dense.weight = model.conv1.dense.weight 90 | # fmodel.conv1.dense.bias = model.conv1.dense.bias 91 | fmodel.to(device) 92 | 93 | model.train() 94 | fmodel.train() 95 | 96 | x_f = data.x.clone().requires_grad_(True) 97 | x = data.x.clone().requires_grad_(True) 98 | 99 | f_res = fmodel(x=x_f, edge_weight=edge_weight) 100 | ref = model(x=x, edge_weight=edge_weight) 101 | 102 | grad = torch.rand_like(f_res) 103 | 104 | f_res.backward(grad) 105 | ref.backward(grad) 106 | 107 | grad_x_f = x_f.grad 108 | grad_x = x.grad 109 | 110 | max_error = torch.max((f_res - ref)).item() 111 | max_error_b = torch.max((grad_x_f - grad_x)).item() 112 | passed = True 113 | if max_error > 1e-5 or np.isnan(max_error): 114 | print("[Forward] on %s, edge_permute %r, edge_weight %r, hidden size %d, flow %s" % (data_name, edge_permute, ew, hidden, flow)) 115 | print("there are %d different entries in overall %d entires. The maximum difference is %f" % 116 | (torch.nonzero(f_res - ref).size(0), f_res.numel(), max_error)) 117 | passed = False 118 | if max_error_b > 1e-5 or np.isnan(max_error_b): 119 | print("[Backward] on %s, edge_permute %r, edge_weight %r, hidden size %d, flow %s" % (data_name, edge_permute, ew, hidden, flow)) 120 | print("there are %d different entries in overall %d entires. 
The maximum difference is %f" % 121 | (torch.nonzero(grad_x_f - grad_x).size(0), grad_x_f.numel(), max_error_b)) 122 | passed = False 123 | return passed 124 | 125 | 126 | num_exp = 0 127 | num_pass = 0 128 | 129 | for ds in datasets: 130 | for ep in edge_permutes: 131 | for ew in edge_weights: 132 | for f in flows: 133 | for i in range(3): 134 | num_exp += 1 135 | if single_test(ds, ep, ew, f): 136 | num_pass += 1 137 | 138 | print("%d out of %d tests passed" % (num_pass, num_exp)) -------------------------------------------------------------------------------- /fuseGNN/testbench/gat_conv_tb.py: -------------------------------------------------------------------------------- 1 | """ 2 | A testbench that verifies both forward and backward kernels of GAT 3 | """ 4 | import argparse 5 | import torch 6 | import torch.nn.functional as F 7 | import torch_scatter 8 | import sys 9 | import torch_geometric.transforms as T 10 | from torch_geometric.utils import add_remaining_self_loops 11 | # import gcnlib 12 | from fuseGNN.dataloader import Citations 13 | from fuseGNN.convs import garGATConv, refGATConv, gasGATConv, geoGATConv 14 | from torch_geometric.utils import degree 15 | from tqdm import tqdm 16 | import numpy as np 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--data_path', type=str, default='/raid/datasets/GNN/') 20 | args = parser.parse_args() 21 | 22 | datasets = ['Cora', 'CiteSeer', 'PubMed'] 23 | edge_permutes = [True, False] 24 | flows = ['target_to_source', 'source_to_target'] 25 | 26 | # configure CUDA 27 | assert torch.cuda.is_available(), "CUDA is not available" 28 | device = torch.device('cuda') 29 | 30 | 31 | class CastratedGAT(torch.nn.Module): 32 | """ 33 | GCN model for profiling 34 | """ 35 | def __init__(self, num_features, data=None, hidden=256, heads=8, fused=False, flow='target_to_source'): 36 | super(CastratedGAT, self).__init__() 37 | self.data = data 38 | if fused: 39 | self.conv1 = garGATConv(in_channels=num_features, out_channels=int(hidden/heads), heads=heads, dropout=0.6, flow=flow) 40 | else: 41 | self.conv1 = refGATConv(in_channels=num_features, out_channels=int(hidden/heads), heads=heads, dropout=0.6, flow=flow, return_mask=True) 42 | 43 | 44 | def forward(self, x, data=None, dp_mask=None, dp_mask_self=None): 45 | if data is None: 46 | data = self.data 47 | return self.conv1(x, data.edge_index, dp_mask=dp_mask, dp_mask_self=dp_mask_self) 48 | 49 | 50 | def single_test(data_name, edge_permute, flow): 51 | # Configure dataset 52 | path = args.data_path + data_name 53 | try: 54 | dataset = Citations(path, data_name, T.NormalizeFeatures()) 55 | data = dataset[0] 56 | # todo: preprocess the degree vector 57 | """ 58 | The data contains 59 | edge_index=[2, N(e)], test_mask=[N(v)], train_mask=[N(v)], val_mask=[N(v)], x=[N(v), dim], y=[N(v)], deg=[N(v)] 60 | The deg is very imbalanced, e.g. 
[1, 168] for Cora, however, I can still try to ignore it at this very begining 61 | """ 62 | except: 63 | print('The dataset does not exist or is not supported.') 64 | sys.exit() 65 | data.to(device) 66 | if edge_permute: 67 | r = torch.randperm(data.edge_index.size(1)) 68 | data.edge_index = data.edge_index[:, r] 69 | hidden = np.random.randint(low=1, high=128) 70 | 71 | model = CastratedGAT(num_features=dataset.num_features, data=data, hidden=hidden * 8, heads=8, fused=False, flow=flow) 72 | model.to(device) 73 | model.train() 74 | 75 | fmodel = CastratedGAT(num_features=dataset.num_features, data=data, hidden=hidden * 8, heads=8, fused=True, flow=flow) 76 | fmodel.to(device) 77 | fmodel.train() 78 | 79 | fmodel.conv1.dense.weight = model.conv1.dense.weight 80 | fmodel.conv1.dense.bias = model.conv1.dense.bias 81 | fmodel.conv1.att.data = model.conv1.att.data 82 | 83 | x = data.x.clone().requires_grad_(True) 84 | x_f = data.x.clone().requires_grad_(True) 85 | 86 | # fmodel.conv1.dense.bias = torch.nn.Parameter(model.conv1.bias) 87 | # fmodel.conv1.dense.weight = model.conv1.dense.weight 88 | # fmodel.conv1.dense.bias = model.conv1.dense.bias 89 | # fmodel.to(device) 90 | ref, dp_mask, dp_mask_self = model(x=x) 91 | dp_mask = dp_mask.detach() 92 | dp_mask_self = dp_mask_self.detach() 93 | f_res = fmodel(x=x_f, dp_mask=dp_mask, dp_mask_self=dp_mask_self) 94 | grad = torch.rand_like(f_res) 95 | f_res.backward(grad) 96 | fgrad = x_f.grad 97 | ref.backward(grad) 98 | refgrad = x.grad 99 | error = torch.abs(f_res - ref) 100 | graderror = torch.abs(fgrad - refgrad) 101 | 102 | fattgrad = fmodel.conv1.att.grad 103 | refattgrad = model.conv1.att.grad 104 | attgraderror = torch.abs(fattgrad - refattgrad) 105 | 106 | # error = error.ge(1e-5).to(torch.float32) 107 | # print(error) 108 | max_error = torch.max(error).item() 109 | passed = True 110 | if max_error > 1e-5 or np.isnan(max_error): 111 | print("[Forward] on %s, edge_permute %r, hidden size %d, flow %s" % (data_name, edge_permute, hidden, flow)) 112 | print("there are %d different entries in overall %d entires. The maximum difference is %f" % 113 | (torch.nonzero(error).size(0), error.numel(), max_error)) 114 | passed = False 115 | max_graderror = torch.max(graderror).item() 116 | if max_graderror > 5e-5 or np.isnan(max_graderror): 117 | print("[Backward] on %s, edge_permute %r, hidden size %d, flow %s" % (data_name, edge_permute, hidden, flow)) 118 | print("there are %d different entries in overall %d entires. The maximum difference is %f" % 119 | (torch.nonzero(graderror).size(0), graderror.numel(), max_graderror)) 120 | passed = False 121 | max_attgraderror = torch.max(attgraderror).item() 122 | if max_attgraderror > 1e-5 or np.isnan(max_attgraderror): 123 | print("[ATT] on %s, edge_permute %r, hidden size %d, flow %s" % (data_name, edge_permute, hidden, flow)) 124 | print("there are %d different entries in overall %d entires. 
The maximum difference is %f" % 125 | (torch.nonzero(attgraderror).size(0), attgraderror.numel(), max_attgraderror)) 126 | passed = False 127 | return passed 128 | 129 | 130 | num_exp = 0 131 | num_pass = 0 132 | 133 | for ds in datasets: 134 | for ep in edge_permutes: 135 | for f in flows: 136 | for i in range(6): 137 | num_exp += 1 138 | if single_test(ds, ep, f): 139 | num_pass += 1 140 | 141 | print("%d out of %d tests passed" % (num_pass, num_exp)) -------------------------------------------------------------------------------- /src/cuda/gcn_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | 8 | /* 9 | * get degree 10 | */ 11 | 12 | #define THREADS 256 13 | #define BLOCKS(N, T) (N + T - 1)/T 14 | 15 | // when using CSR/CSC format 16 | template 17 | __global__ void from_ptr( 18 | int* __restrict__ tar_ptr, 19 | scalar_t* __restrict__ degree, 20 | int num_nodes 21 | ){ 22 | unsigned int stride = blockDim.x * gridDim.x; 23 | 24 | for (unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < num_nodes; tid += stride){ 25 | degree[tid] = tar_ptr[tid + 1] - tar_ptr[tid] + 1; 26 | } 27 | } 28 | 29 | 30 | template 31 | __device__ void smem_reduction(volatile scalar_t* sdata, unsigned int tid){ 32 | if (blockSize >= 1024){ 33 | if (tid < 512){ 34 | sdata[tid] += sdata[tid + 512]; 35 | } 36 | __syncthreads(); 37 | } 38 | if (blockSize >= 512){ 39 | if (tid < 256){ 40 | sdata[tid] += sdata[tid + 256]; 41 | } 42 | __syncthreads(); 43 | } 44 | if (blockSize >= 256){ 45 | if (tid < 128){ 46 | sdata[tid] += sdata[tid + 128]; 47 | } 48 | __syncthreads(); 49 | } 50 | if (blockSize >= 128){ 51 | if (tid < 64){ 52 | sdata[tid] += sdata[tid + 64]; 53 | } 54 | __syncthreads(); 55 | } 56 | if (tid < 32){ 57 | if (blockSize >= 64)sdata[tid] += sdata[tid + 32]; 58 | if (blockSize >= 32)sdata[tid] += sdata[tid + 16]; 59 | if (blockSize >= 16)sdata[tid] += sdata[tid + 8]; 60 | if (blockSize >= 8)sdata[tid] += sdata[tid + 4]; 61 | if (blockSize >= 4)sdata[tid] += sdata[tid + 2]; 62 | if (blockSize >= 2)sdata[tid] += sdata[tid + 1]; 63 | } 64 | } 65 | 66 | 67 | template 68 | __global__ void from_weight( 69 | scalar_t* __restrict__ edge_weight, 70 | scalar_t* __restrict__ degree, 71 | int* __restrict__ tar_ptr 72 | ){ 73 | __shared__ scalar_t deg[blockSize]; 74 | unsigned int tid = threadIdx.x; 75 | deg[tid] = 0; 76 | unsigned int tar_id = blockIdx.x; 77 | for (unsigned int e_idx = tar_ptr[tar_id] + threadIdx.x; e_idx < tar_ptr[tar_id + 1]; e_idx += blockDim.x){ 78 | deg[tid] += edge_weight[e_idx]; 79 | } 80 | __syncthreads(); 81 | smem_reduction(deg, tid); 82 | if (tid == 0) degree[tar_id] = deg[0] + 1; 83 | } 84 | 85 | // When using COO format 86 | template 87 | __global__ void scatter_add( 88 | scalar_t* __restrict__ edge_weight, 89 | scalar_t* __restrict__ degree, 90 | int* __restrict__ tar_index, 91 | unsigned int num_edge 92 | ){ 93 | unsigned int stride = blockDim.x * gridDim.x; 94 | for (unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < num_edge; tid += stride){ 95 | atomicAdd(°ree[tar_index[tid]], edge_weight[tid]); 96 | } 97 | } 98 | 99 | torch::Tensor get_degree_cuda( 100 | torch::Tensor tar_ptr, 101 | torch::Tensor src_index, 102 | torch::optional optional_edge_weight, 103 | int num_nodes, 104 | bool tar_to_src 105 | ){ 106 | auto options = torch::TensorOptions().dtype(torch::kFloat32).device(tar_ptr.device()); 107 | 108 | cudaStream_t stream = 
at::cuda::getCurrentCUDAStream(); 109 | 110 | torch::Tensor degree; 111 | 112 | if (optional_edge_weight.has_value()){ 113 | degree = torch::empty({num_nodes,}, options); 114 | torch::Tensor edge_weight; 115 | edge_weight = optional_edge_weight.value().contiguous(); 116 | if (tar_to_src){ 117 | AT_DISPATCH_FLOATING_TYPES(degree.type(), "get degree from weight", ([&]{ 118 | from_weight<<>>( 119 | edge_weight.data(), degree.data(), 120 | tar_ptr.data()); 121 | })); 122 | }else{ 123 | degree = torch::ones({num_nodes,}, options); 124 | 125 | AT_DISPATCH_FLOATING_TYPES(degree.type(), "get degree from weight scatter", ([&]{ 126 | scatter_add<<>>( 127 | edge_weight.data(), degree.data(), 128 | src_index.data(), edge_weight.size(0)); 129 | })); 130 | } 131 | }else{ 132 | degree = torch::empty({num_nodes,}, options); 133 | AT_DISPATCH_FLOATING_TYPES(degree.type(), "get degree from ptr", ([&]{ 134 | from_ptr<<>>( 135 | tar_ptr.data(), degree.data(), num_nodes 136 | ); 137 | })); 138 | } 139 | 140 | return degree; 141 | } 142 | 143 | 144 | template 145 | __global__ void update_weight( 146 | const int* __restrict__ src_index, 147 | const int* __restrict__ tar_index, 148 | const scalar_t* __restrict__ edge_weight, 149 | scalar_t* __restrict__ out_edge_weight, 150 | const scalar_t* __restrict__ degree, 151 | int num_edge 152 | ){ 153 | for(unsigned int tid=blockIdx.x * blockDim.x + threadIdx.x; tid < num_edge; tid += gridDim.x * blockDim.x){ 154 | scalar_t res = edge_weight[tid] / sqrtf(degree[src_index[tid]])/sqrtf(degree[tar_index[tid]]); 155 | if (isinf(res)) res = 0; 156 | out_edge_weight[tid] = res; 157 | } 158 | } 159 | 160 | 161 | template 162 | __global__ void get_weight( 163 | const int* __restrict__ src_index, 164 | const int* __restrict__ tar_index, 165 | scalar_t* __restrict__ edge_weight, 166 | scalar_t* __restrict__ degree, 167 | int num_edge 168 | ){ 169 | for(unsigned int tid=blockIdx.x * blockDim.x + threadIdx.x; tid < num_edge; tid += gridDim.x * blockDim.x){ 170 | scalar_t res = 1/sqrtf(degree[src_index[tid]] * degree[tar_index[tid]]); 171 | if (isinf(res)) res = 0; 172 | edge_weight[tid] = res; 173 | } 174 | } 175 | 176 | 177 | // CUDA Edge processing declaration 178 | std::vector gcn_gar_egde_weight_cuda( 179 | torch::Tensor src_index, 180 | torch::Tensor tar_ptr, 181 | torch::Tensor tar_index, 182 | int num_nodes, 183 | torch::optional optional_edge_weight, 184 | bool tar_to_src 185 | ){ 186 | // Step 1: get degree 187 | auto degree = get_degree_cuda(tar_ptr, src_index, optional_edge_weight, num_nodes, tar_to_src); 188 | 189 | // Step 3: initialize the edge_weight with 1s if not provided 190 | unsigned int Ne = src_index.size(0); 191 | torch::Tensor edge_weight; 192 | 193 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 194 | 195 | if (optional_edge_weight.has_value()){ 196 | edge_weight = optional_edge_weight.value().contiguous(); 197 | auto out_edge_weight = torch::empty_like(edge_weight); 198 | AT_DISPATCH_FLOATING_TYPES(edge_weight.type(), "update weight", ([&]{ 199 | update_weight<<>>( 200 | src_index.data(), tar_index.data(), 201 | edge_weight.data(), 202 | out_edge_weight.data(), 203 | degree.data(), src_index.size(0) 204 | ); 205 | })); 206 | return {out_edge_weight, degree}; 207 | }else{ 208 | auto options = torch::TensorOptions().dtype(torch::kFloat32).device(src_index.device()); 209 | edge_weight = torch::empty({Ne,}, options); 210 | AT_DISPATCH_FLOATING_TYPES(edge_weight.type(), "update weight", ([&]{ 211 | get_weight<<>>( 212 | src_index.data(), 
tar_index.data(), 213 | edge_weight.data(), 214 | degree.data(), src_index.size(0) 215 | ); 216 | })); 217 | return {edge_weight, degree}; 218 | } 219 | } 220 | 221 | 222 | // CUDA Edge processing declaration 223 | std::vector gcn_gas_edge_weight_cuda( 224 | torch::Tensor src_index, 225 | torch::Tensor tar_index, 226 | int num_nodes, 227 | torch::optional optional_edge_weight, 228 | bool tar_to_src 229 | ){ 230 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 231 | unsigned int Ne = src_index.size(0); 232 | // Step 0: get edge_weight 233 | torch::Tensor edge_weight; 234 | 235 | if (optional_edge_weight.has_value()){ 236 | edge_weight = optional_edge_weight.value().contiguous(); 237 | }else{ 238 | auto options_w = torch::TensorOptions().dtype(torch::kFloat32).device(src_index.device()); 239 | edge_weight = torch::ones({Ne,}, options_w); 240 | } 241 | // Step 1: get degree 242 | auto options_d = torch::TensorOptions().dtype(torch::kFloat32).device(tar_index.device()); 243 | 244 | auto degree = torch::ones({num_nodes,}, options_d); 245 | 246 | if (tar_to_src){ 247 | AT_DISPATCH_FLOATING_TYPES(degree.type(), "get degree from weight scatter", ([&]{ 248 | scatter_add<<>>( 249 | edge_weight.data(), degree.data(), 250 | tar_index.data(), edge_weight.size(0)); 251 | })); 252 | }else{ 253 | AT_DISPATCH_FLOATING_TYPES(degree.type(), "get degree from weight scatter", ([&]{ 254 | scatter_add<<>>( 255 | edge_weight.data(), degree.data(), 256 | src_index.data(), edge_weight.size(0)); 257 | })); 258 | } 259 | auto out_edge_weight = torch::empty_like(edge_weight); 260 | 261 | AT_DISPATCH_FLOATING_TYPES(edge_weight.type(), "update weight", ([&]{ 262 | update_weight<<>>( 263 | src_index.data(), tar_index.data(), 264 | edge_weight.data(), 265 | out_edge_weight.data(), 266 | degree.data(), src_index.size(0) 267 | ); 268 | })); 269 | 270 | return {out_edge_weight, degree}; 271 | } 272 | -------------------------------------------------------------------------------- /fuseGNN/convs/gcn_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_scatter 3 | from torch_geometric.nn import GCNConv as geoGCNConv 4 | from fuseGNN.functional import coo2csr, gcn_gar_edge_weight, csr2csc, fused_gar_agg , gcn_gas_edge_weight, fused_gas_agg 5 | import torch.nn.functional as F 6 | 7 | 8 | # the reference module 9 | 10 | 11 | class refGCNConv(torch.nn.Module): 12 | def __init__(self, in_channels, out_channels, cached=False, bias=True, 13 | flow='target_to_source'): 14 | """ 15 | Args: 16 | in_channels (int): Size of each input sample 17 | out_channels (int): Size of each output sample. 18 | cached (bool, optional): if set to True, the layer will cache 19 | the computation of D^{-0.5}\hat{A}of D^{-0.5} on the first 20 | execution, and it will be used for further executions. 21 | This is only helpful in transductive learning 22 | bias (bool, optional): If set to True, there will be a bias 23 | flow (str): could be the following two conditions 24 | 'source_to_target': edge_index[0] is the source nodes, [1] is the target nodes 25 | 'target_to_source': edge_index[0] is the target nodes, [0] is the source nodes 26 | fused (bool, optional): If set to True, the gcnlib.gcn_aggregate_f_cuda will be used. 
Default: False, 27 | verify (bool, optional): If set to True, it will output the difference between fused and unfused version 28 | """ 29 | super(refGCNConv, self).__init__() 30 | self.flow = flow 31 | assert self.flow in ['source_to_target', 'target_to_source'] 32 | self.tid, self.sid = (0, 1) if self.flow == 'target_to_source' else (1, 0) 33 | 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | # The dense layer for Update stage 37 | self.dense = torch.nn.Linear(in_features=in_channels, out_features=out_channels, bias=bias) 38 | 39 | self.cached = cached 40 | self.cached_num_edges = None 41 | self.cached_deg = None 42 | self.cached_deg_int = None 43 | self.cached_weight = None 44 | self.cached_edge_ptr = None 45 | 46 | def forward(self, x, edge_index, edge_weight=None): 47 | """ 48 | Args: 49 | x (float32 [N(v), in_channels]) : matrix of feature vectors 50 | edge_index (int64 [2, N(v)]) : the list of edges 51 | edge_weight (float32 [N(v)], optional) : the weight on each edge, default: 1 52 | """ 53 | # step 0: split the edge index and convert them to int32 54 | src_index, tar_index = (edge_index[self.sid], edge_index[self.tid]) 55 | 56 | # step 1: Update, forward input features into linear layer 57 | x = self.dense(x) 58 | 59 | # step 2: getting the D^{-0.5}\hat{A}of D^{-0.5} matrix 60 | self.get_adj(src_index, tar_index, x.size(0), edge_weight) 61 | 62 | # Aggregation stage 63 | return self.propagate(x, src_index, tar_index) 64 | 65 | def get_adj(self, src_index, tar_index, num_nodes, edge_weight=None): 66 | """ 67 | Args: 68 | src_index (int64 [N(e)]) : COO index of source 69 | tar_index (int64 [N(e)]) : COO index of target 70 | num_nodes (int64): number of nodes in the input graph 71 | edge_weight (float32 [N(e)], optional) : the weight on each edge, default: 1 72 | """ 73 | if self.cached and self.cached_num_edges is not None: 74 | # when the result is cached, and this is not the first execution 75 | if src_index.size(0) != self.cached_num_edges: 76 | raise RuntimeError( 77 | 'Chached {} number of edges, but found {}. Please ' 78 | 'disable the caching behavior of this layer by removing ' 79 | 'the cached=True argument in its consturctor'.format( 80 | self.cached_num_edges, src_index.size(0))) 81 | 82 | if not self.cached or self.cached_num_edges is None: 83 | # when the result is not cached, or its the first execution 84 | self.cached_num_edges = src_index.size(0) 85 | # update the edge weight based on degree information 86 | self.processing_edge(src_index, tar_index, num_nodes, edge_weight) 87 | 88 | def processing_edge(self, src_index, tar_index, num_nodes, edge_weight=None): 89 | """ 90 | Update the edge_weights with degree information 91 | w = w/\sqrt((d_s + 1)(d_t + 1)) 92 | Remark: we don't add the self-loops into the edges, instead we will add them manually. 
93 | Args: 94 | src_index (int64 [N(e)]) : COO index of source 95 | tar_index (int64 [N(e)]) : COO index of target 96 | num_nodes (int64): number of nodes in the input graph 97 | edge_weight (float32 [N(v)], optional) : the weight on each edge, default: 1 98 | """ 99 | if edge_weight is None: 100 | edge_weight = torch.ones(size=(src_index.size(0),), dtype=torch.float32, device=src_index.device) 101 | 102 | # The index to get degree is the first vector of edge_index 103 | if self.flow == "source_to_target": 104 | deg = torch_scatter.scatter_add(src=edge_weight, index=src_index, dim=0, dim_size=num_nodes) + 1 105 | else: 106 | deg = torch_scatter.scatter_add(src=edge_weight, index=tar_index, dim=0, dim_size=num_nodes) + 1 107 | deg_inv_sqrt = deg.pow(-0.5) 108 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 109 | edge_weight = deg_inv_sqrt[src_index] * edge_weight * deg_inv_sqrt[tar_index] 110 | self.cached_deg = deg 111 | self.cached_weight = edge_weight 112 | self.cached_num_edges = src_index.size(0) 113 | 114 | def propagate(self, feature, src_index, tar_index): 115 | """ 116 | Args: 117 | feature: the feature vectors float32 [N(v), dim] 118 | edge_index: the list of edges [2, N(e)] 119 | """ 120 | # step 1: scatter the feature vectors to the extended feature map for reduction 121 | extended_fm = torch.index_select(feature, dim=0, index=src_index) 122 | 123 | # step 2: apply the weights 124 | extended_fm = extended_fm * self.cached_weight.unsqueeze(1) 125 | 126 | # step 3: scatter to the output feature 127 | out = torch_scatter.scatter_add(src=extended_fm, index=tar_index, dim=0, dim_size=feature.size(0)) 128 | 129 | # step 4: apply self-loops 130 | out += feature * self.cached_deg.pow(-1).unsqueeze(1) 131 | 132 | return out 133 | 134 | 135 | 136 | # the final fused version 137 | 138 | class garGCNConv(torch.nn.Module): 139 | def __init__(self, in_channels, out_channels, cached=False, bias=True, 140 | flow='target_to_source'): 141 | """ 142 | Args: 143 | in_channels (int): Size of each input sample 144 | out_channels (int): Size of each output sample. 145 | cached (bool, optional): if set to True, the layer will cache 146 | the computation of D^{-0.5}\hat{A}of D^{-0.5} on the first 147 | execution, and it will be used for further executions. 
148 | This is only helpful in transductive learning 149 | bias (bool, optional): If set to True, there will be a bias 150 | flow (str): could be the following two conditions 151 | 'source_to_target': edge_index[0] is the source nodes, [1] is the target nodes 152 | 'target_to_source': edge_index[0] is the target nodes, [0] is the source nodes 153 | """ 154 | super(garGCNConv, self).__init__() 155 | self.flow = flow 156 | assert self.flow in ['source_to_target', 'target_to_source'] 157 | 158 | self.tid, self.sid = (0, 1) if self.flow == 'target_to_source' else (1, 0) 159 | 160 | self.in_channels = in_channels 161 | self.out_channels = out_channels 162 | # The dense layer for Update stage 163 | self.dense = torch.nn.Linear(in_features=in_channels, out_features=out_channels, bias=bias) 164 | 165 | self.cached = cached 166 | 167 | self.cached_tar_ptr = None 168 | self.cached_src_index = None 169 | self.cached_edge_weight_f = None 170 | self.cached_self_edge_weight = None 171 | self.cached_src_ptr = None 172 | self.cached_tar_index = None 173 | self.cached_edge_weight_b = None 174 | 175 | self.cached_num_edges = None 176 | 177 | # It seems that using multiple streams doesn't help 178 | # self.stream1 = torch.cuda.Stream() 179 | # self.stream2 = torch.cuda.Stream() 180 | 181 | def forward(self, x, edge_index=None, edge_weight=None, self_edge_weight=None, tar_ptr=None, 182 | src_index=None, src_ptr=None, tar_index=None, edge_weight_b=None): 183 | 184 | x = self.dense(x) 185 | if not self.cached or self.cached_num_edges is None: 186 | # when the results are not cached, or it is the first execution. 187 | if tar_ptr is not None: # when the CSR & CSC format are provided 188 | self.cached_tar_ptr = tar_ptr 189 | self.cached_src_index = src_index 190 | self.cached_edge_weight_f = edge_weight 191 | self.cached_self_edge_weight = self_edge_weight 192 | self.cached_src_ptr = src_ptr 193 | self.cached_tar_index = tar_index 194 | self.cached_edge_weight_b = edge_weight_b 195 | 196 | self.cached_num_edges = tar_ptr.size(0) 197 | else: 198 | num_nodes = x.size(0) 199 | # convert the edge lists to int32 200 | edge_index = edge_index.to(torch.int32) 201 | src_index, tar_index = (edge_index[self.sid], edge_index[self.tid]) 202 | # convert coo format to csr format 203 | self.cached_src_index, tar_index, self.cached_tar_ptr, edge_weight_f = coo2csr(src_index, tar_index, 204 | num_nodes, edge_weight, False) 205 | # update edge weight 206 | self.cached_edge_weight_f, self.cached_self_edge_weight = gcn_gar_edge_weight(self.cached_src_index, 207 | self.cached_tar_ptr, tar_index, 208 | num_nodes, edge_weight_f, 209 | self.flow) 210 | # get the csc format for backward pass 211 | self.cached_src_ptr, self.cached_tar_index, self.cached_edge_weight_b = csr2csc(self.cached_tar_ptr, 212 | self.cached_src_index, 213 | self.cached_edge_weight_f, 214 | num_nodes) 215 | self.cached_num_edges = self.cached_tar_ptr.size(0) 216 | 217 | return fused_gar_agg(feature=x, src_index=self.cached_src_index, tar_ptr=self.cached_tar_ptr, 218 | edge_weight_f=self.cached_edge_weight_f, self_edge_weight=self.cached_self_edge_weight, 219 | tar_index=self.cached_tar_index, src_ptr=self.cached_src_ptr, 220 | edge_weight_b=self.cached_edge_weight_b, require_edge_weight=False) 221 | 222 | 223 | # the GAS version 224 | 225 | class gasGCNConv(torch.nn.Module): 226 | def __init__(self, in_channels, out_channels, cached=False, bias=True, 227 | flow='target_to_source'): 228 | """ 229 | Args: 230 | in_channels (int): Size of each input sample 231 | 
out_channels (int): Size of each output sample. 232 | cached (bool, optional): if set to True, the layer will cache 233 | the computation of D^{-0.5}\hat{A}of D^{-0.5} on the first 234 | execution, and it will be used for further executions. 235 | This is only helpful in transductive learning 236 | bias (bool, optional): If set to True, there will be a bias 237 | flow (str): could be the following two conditions 238 | 'source_to_target': edge_index[0] is the source nodes, [1] is the target nodes 239 | 'target_to_source': edge_index[0] is the target nodes, [0] is the source nodes 240 | """ 241 | super(gasGCNConv, self).__init__() 242 | self.flow = flow 243 | assert self.flow in ['source_to_target', 'target_to_source'] 244 | 245 | self.tid, self.sid = (0, 1) if self.flow == 'target_to_source' else (1, 0) 246 | 247 | self.in_channels = in_channels 248 | self.out_channels = out_channels 249 | # The dense layer for Update stage 250 | self.dense = torch.nn.Linear(in_features=in_channels, out_features=out_channels, bias=bias) 251 | 252 | self.cached = cached 253 | self.cached_num_edges = None 254 | 255 | self.cached_src_index = None 256 | self.cached_tar_index = None 257 | self.cached_edge_weight = None 258 | self.cached_self_edge_weight = None 259 | # It seems that using multiple streams doesn't help 260 | # self.stream1 = torch.cuda.Stream() 261 | # self.stream2 = torch.cuda.Stream() 262 | 263 | def forward(self, x, edge_index=None, edge_weight=None, src_index=None, tar_index=None, self_edge_weight=None): 264 | x = self.dense(x) 265 | if not self.cached or self.cached_num_edges is None: 266 | if self_edge_weight is not None: 267 | self.cached_src_index = src_index 268 | self.cached_tar_index = tar_index 269 | self.cached_edge_weight = edge_weight 270 | self.cached_self_edge_weight = self_edge_weight 271 | self.cached_num_edges = self.cached_src_index.size(0) 272 | else: 273 | num_nodes = x.size(0) 274 | edge_index = edge_index.to(torch.int32) 275 | self.cached_src_index, self.cached_tar_index = (edge_index[self.sid], edge_index[self.tid]) 276 | self.cached_edge_weight, self.cached_self_edge_weight = gcn_gas_edge_weight(self.cached_src_index, 277 | self.cached_tar_index, 278 | num_nodes, 279 | edge_weight, 280 | self.flow) 281 | self.cached_num_edges = self.cached_src_index.size(0) 282 | return fused_gas_agg(feature=x, src_index=self.cached_src_index, tar_index=self.cached_tar_index, 283 | edge_weight=self.cached_edge_weight, self_edge_weight=self.cached_self_edge_weight, 284 | require_edge_weight=False) 285 | -------------------------------------------------------------------------------- /src/cuda/gat_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | 8 | #define BLOCKS(N, T) (N + T - 1)/T 9 | 10 | 11 | template 12 | __device__ void smem_reduce_v2(volatile scalar_t* sdata, unsigned int tid, unsigned int reduce_len){ 13 | if (reduce_len > 512){ 14 | if(tid < 512){ 15 | sdata[tid] += sdata[tid + 512]; 16 | } 17 | __syncthreads(); 18 | } 19 | if (reduce_len > 256){ 20 | if(tid < 256){ 21 | sdata[tid] += sdata[tid + 256]; 22 | } 23 | __syncthreads(); 24 | } 25 | if (reduce_len > 128){ 26 | if(tid < 128){ 27 | sdata[tid] += sdata[tid + 128]; 28 | } 29 | __syncthreads(); 30 | } 31 | if (reduce_len > 64){ 32 | if(tid < 64){ 33 | sdata[tid] += sdata[tid + 64]; 34 | } 35 | __syncthreads(); 36 | } 37 | if (tid < 32){ 38 | if (reduce_len > 32) sdata[tid] += sdata[tid + 32]; 39 | if 
(reduce_len > 16) sdata[tid] += sdata[tid + 16]; 40 | if (reduce_len > 8) sdata[tid] += sdata[tid + 8]; 41 | if (reduce_len > 4) sdata[tid] += sdata[tid + 4]; 42 | if (reduce_len > 2) sdata[tid] += sdata[tid + 2]; 43 | if (reduce_len > 1) sdata[tid] += sdata[tid + 1]; 44 | } 45 | } 46 | 47 | 48 | template 49 | __device__ void smem_reduce_v3(volatile scalar_t* sdata, unsigned int tid){ 50 | if (blockSize >= 1024){ 51 | if (tid < 512){ 52 | sdata[tid] += sdata[tid + 512]; 53 | } 54 | __syncthreads(); 55 | } 56 | if (blockSize >= 512){ 57 | if (tid < 256){ 58 | sdata[tid] += sdata[tid + 256]; 59 | } 60 | __syncthreads(); 61 | } 62 | if (blockSize >= 256){ 63 | if (tid < 128){ 64 | sdata[tid] += sdata[tid + 128]; 65 | } 66 | __syncthreads(); 67 | } 68 | if (blockSize >= 128){ 69 | if (tid < 64){ 70 | sdata[tid] += sdata[tid + 64]; 71 | } 72 | __syncthreads(); 73 | } 74 | if (tid < 32){ 75 | if (blockSize >= 64) sdata[tid] += sdata[tid + 32]; 76 | if (blockSize >= 32) sdata[tid] += sdata[tid + 16]; 77 | if (blockSize >= 16) sdata[tid] += sdata[tid + 8]; 78 | if (blockSize >= 8) sdata[tid] += sdata[tid + 4]; 79 | if (blockSize >= 4) sdata[tid] += sdata[tid + 2]; 80 | if (blockSize >= 2) sdata[tid] += sdata[tid + 1]; 81 | } 82 | } 83 | 84 | 85 | template 86 | __global__ void gar_get_alpha( 87 | const torch::PackedTensorAccessor e_pre, 88 | const int* __restrict__ src_ptr, 89 | const int* __restrict__ tar_index, 90 | scalar_t* alpha, 91 | scalar_t* mask_lrelu, 92 | scalar_t* e_sum, 93 | scalar_t* e_, 94 | float negative_slope 95 | ){ 96 | // each thread block handles a single target 97 | unsigned int src_id = blockIdx.x; 98 | unsigned int tid = threadIdx.x; 99 | unsigned int e_start = src_ptr[src_id]; 100 | unsigned int e_bound = src_ptr[src_id + 1]; 101 | scalar_t e_pre_src = e_pre[src_id][0]; 102 | scalar_t e = 0; 103 | 104 | for (unsigned int e_id = e_start + tid; e_id < e_bound; e_id += blockDim.x){ 105 | unsigned int tar_id = tar_index[e_id]; 106 | // add attention factors 107 | e = e_pre_src + e_pre[tar_id][1]; 108 | // leaky RelU 109 | if (e < 0){ 110 | e *= negative_slope; 111 | mask_lrelu[e_id] = negative_slope; 112 | }else{ 113 | mask_lrelu[e_id] = 1; 114 | } 115 | e = exp(e); 116 | atomicAdd(&e_sum[tar_id], e); 117 | alpha[e_id] = e; 118 | e_[e_id] = e; 119 | } 120 | } 121 | 122 | 123 | template 124 | __global__ void get_alpha_self( 125 | const torch::PackedTensorAccessor e_pre, 126 | scalar_t* e_sum, scalar_t* e_self, scalar_t* mask_lrelu_self, scalar_t* alpha_self, 127 | unsigned int num_node, float negative_slope 128 | ){ 129 | for (unsigned int src_id = blockIdx.x * blockDim.x + threadIdx.x; src_id < num_node; src_id += blockDim.x * gridDim.x){ 130 | scalar_t e = e_pre[src_id][0] + e_pre[src_id][1]; 131 | if (e < 0){ 132 | e *= negative_slope; 133 | mask_lrelu_self[src_id] = negative_slope; 134 | }else{ 135 | mask_lrelu_self[src_id] = 1; 136 | } 137 | e = exp(e); 138 | e_self[src_id] = e; 139 | scalar_t sum = e_sum[src_id] + e; 140 | e_sum[src_id] = sum; 141 | alpha_self[src_id] = e / sum; 142 | } 143 | } 144 | 145 | 146 | #define THREADS 256 147 | 148 | 149 | template 150 | __global__ void alpha_normalize_kernel( 151 | scalar_t* alpha, 152 | const scalar_t* __restrict__ e_sum, 153 | const int* __restrict__ tar_index, 154 | unsigned int num_edge 155 | ){ 156 | for (unsigned int e_id = blockIdx.x * blockDim.x + threadIdx.x; e_id < num_edge; e_id += blockDim.x * gridDim.x){ 157 | alpha[e_id] /= (e_sum[tar_index[e_id]]); 158 | } 159 | } 160 | 161 | 162 | std::vector 
gat_gar_edge_weight_cuda( 163 | torch::Tensor e_pre, 164 | torch::Tensor src_ptr, 165 | torch::Tensor tar_index, 166 | float negative_slope 167 | ){ 168 | unsigned int num_node = e_pre.size(0); 169 | unsigned int num_edge = tar_index.size(0); 170 | 171 | auto options = torch::TensorOptions().dtype(torch::kFloat32).device(e_pre.device()); 172 | auto alpha = torch::empty({num_edge, 1}, options); 173 | auto mask_lrelu = torch::empty_like(alpha); 174 | auto alpha_self = torch::empty({num_node, 1}, options); 175 | auto e_sum = torch::zeros_like(alpha_self); 176 | auto mask_lrelu_self = torch::empty_like(alpha_self); 177 | 178 | auto e = torch::empty_like(alpha); 179 | auto e_self = torch::empty_like(alpha_self); 180 | 181 | AT_DISPATCH_FLOATING_TYPES(e_pre.type(), "gcn_aggregate_f_kernel", ([&]{ 182 | gar_get_alpha<<>>( 183 | e_pre.packed_accessor(), 184 | src_ptr.data(), tar_index.data(), 185 | alpha.data(), mask_lrelu.data(), 186 | e_sum.data(), e.data(), negative_slope 187 | ); 188 | })); 189 | 190 | AT_DISPATCH_FLOATING_TYPES(e_pre.type(), "gcn_aggregate_f_kernel", ([&]{ 191 | get_alpha_self<<>>( 192 | e_pre.packed_accessor(), 193 | e_sum.data(), e_self.data(), mask_lrelu_self.data(), 194 | alpha_self.data(), num_node, negative_slope 195 | ); 196 | })); 197 | 198 | AT_DISPATCH_FLOATING_TYPES(e_pre.type(), "gcn_aggregate_f_kernel", ([&]{ 199 | alpha_normalize_kernel<<>>( 200 | alpha.data(), e_sum.data(), 201 | tar_index.data(), num_edge 202 | ); 203 | })); 204 | 205 | return {alpha, alpha_self, mask_lrelu, mask_lrelu_self, e_sum, e, e_self}; 206 | } 207 | 208 | 209 | template 210 | __global__ void gas_get_alpha( 211 | const torch::PackedTensorAccessor e_pre, 212 | const int* __restrict__ src_index, 213 | const int* __restrict__ tar_index, 214 | scalar_t* alpha, 215 | scalar_t* mask_lrelu, 216 | scalar_t* e_sum, 217 | scalar_t* e_, 218 | float negative_slope, 219 | unsigned int num_edge 220 | ){ 221 | for (unsigned int e_id = blockIdx.x * blockDim.x + threadIdx.x; e_id < num_edge; e_id += blockDim.x * gridDim.x){ 222 | unsigned int tar_id = tar_index[e_id]; 223 | unsigned int src_id = src_index[e_id]; 224 | scalar_t e = e_pre[src_id][0] + e_pre[tar_id][1]; 225 | if (e < 0){ 226 | e *= negative_slope; 227 | mask_lrelu[e_id] = negative_slope; 228 | }else{ 229 | mask_lrelu[e_id] = 1; 230 | } 231 | e = exp(e); 232 | atomicAdd(&e_sum[tar_id], e); 233 | alpha[e_id] = e; 234 | e_[e_id] = e; 235 | } 236 | } 237 | 238 | 239 | 240 | std::vector gat_gas_edge_weight_cuda( 241 | torch::Tensor e_pre, 242 | torch::Tensor src_index, 243 | torch::Tensor tar_index, 244 | float negative_slope 245 | ){ 246 | unsigned int num_node = e_pre.size(0); 247 | unsigned int num_edge = tar_index.size(0); 248 | auto options = torch::TensorOptions().dtype(torch::kFloat32).device(e_pre.device()); 249 | auto alpha = torch::empty({num_edge, 1}, options); 250 | auto mask_lrelu = torch::empty_like(alpha); 251 | auto alpha_self = torch::empty({num_node, 1}, options); 252 | auto e_sum = torch::zeros_like(alpha_self); 253 | auto mask_lrelu_self = torch::empty_like(alpha_self); 254 | 255 | auto e = torch::empty_like(alpha); 256 | auto e_self = torch::empty_like(alpha_self); 257 | 258 | AT_DISPATCH_FLOATING_TYPES(e_pre.type(), "gcn_aggregate_f_kernel", ([&]{ 259 | gas_get_alpha<<>>( 260 | e_pre.packed_accessor(), 261 | src_index.data(), tar_index.data(), 262 | alpha.data(), mask_lrelu.data(), 263 | e_sum.data(), e.data(), negative_slope, num_edge 264 | ); 265 | })); 266 | 267 | AT_DISPATCH_FLOATING_TYPES(e_pre.type(), 
"gcn_aggregate_f_kernel", ([&]{ 268 | get_alpha_self<<>>( 269 | e_pre.packed_accessor(), 270 | e_sum.data(), e_self.data(), mask_lrelu_self.data(), 271 | alpha_self.data(), num_node, negative_slope 272 | ); 273 | })); 274 | 275 | AT_DISPATCH_FLOATING_TYPES(e_pre.type(), "gcn_aggregate_f_kernel", ([&]{ 276 | alpha_normalize_kernel<<>>( 277 | alpha.data(), e_sum.data(), 278 | tar_index.data(), num_edge 279 | ); 280 | })); 281 | 282 | return {alpha, alpha_self, mask_lrelu, mask_lrelu_self, e_sum, e, e_self}; 283 | } 284 | 285 | 286 | 287 | 288 | template 289 | __global__ void gat_gar_edge_weight_b_kernel( 290 | const scalar_t* __restrict__ grad_alpha, 291 | const scalar_t* __restrict__ alpha, 292 | const int* __restrict__ tar_index, 293 | const scalar_t* __restrict__ e_sum, 294 | scalar_t* grad_e_sum, 295 | scalar_t* grad_e, 296 | unsigned int num_edge 297 | ){ 298 | for (unsigned int e_id = blockIdx.x * blockDim.x + threadIdx.x; e_id < num_edge; e_id += blockDim.x * gridDim.x){ 299 | unsigned int tar_id = tar_index[e_id]; 300 | scalar_t e_sum_tar = e_sum[tar_id]; 301 | scalar_t grad_alpha_buffer = grad_alpha[e_id]; 302 | scalar_t g_e_sum = -(grad_alpha_buffer * alpha[e_id]); 303 | atomicAdd(&grad_e_sum[tar_id], g_e_sum); 304 | grad_e[e_id] = grad_alpha_buffer / e_sum_tar; 305 | } 306 | } 307 | 308 | template 309 | __global__ void gat_gar_edge_weight_b2_kernel( 310 | const scalar_t* __restrict__ grad_alpha_self, 311 | const scalar_t* __restrict__ alpha_self, 312 | const scalar_t* __restrict__ e_sum, 313 | scalar_t* grad_e_sum, //scalar_t* grad_e_self, 314 | const scalar_t* __restrict__ e_self, 315 | const scalar_t* __restrict__ mask_lrelu_self, 316 | torch::PackedTensorAccessor grad_e_pre, 317 | unsigned int num_node 318 | ){ 319 | for (unsigned int n_id = blockIdx.x * blockDim.x + threadIdx.x; n_id < num_node; n_id += blockDim.x * gridDim.x){ 320 | scalar_t e_sum_ = e_sum[n_id]; 321 | scalar_t grad_alpha_self_ = grad_alpha_self[n_id]; 322 | scalar_t grad_e_sum_ = (grad_e_sum[n_id] - grad_alpha_self_ * alpha_self[n_id]) / e_sum_; 323 | grad_e_sum[n_id] = grad_e_sum_; 324 | scalar_t grad_e_self = (grad_alpha_self_ / e_sum_ + grad_e_sum_) * e_self[n_id] * mask_lrelu_self[n_id]; 325 | grad_e_pre[n_id][0] = grad_e_self; 326 | grad_e_pre[n_id][1] = grad_e_self; 327 | } 328 | } 329 | 330 | 331 | template 332 | __global__ void gat_gar_edge_weight_b3_kernel( 333 | scalar_t* grad_e, 334 | const scalar_t* __restrict__ e, 335 | const scalar_t* __restrict__ grad_e_sum, 336 | const scalar_t* __restrict__ mask_lrelu, 337 | const int* __restrict__ tar_index, 338 | const int* __restrict__ src_index, 339 | torch::PackedTensorAccessor grad_e_pre, 340 | unsigned int num_edge 341 | ){ 342 | for (unsigned int e_id = blockIdx.x * blockDim.x + threadIdx.x; e_id < num_edge; e_id += blockDim.x * gridDim.x){ 343 | unsigned int tar_id = tar_index[e_id]; 344 | scalar_t grad_e_ = (grad_e[e_id] + grad_e_sum[tar_id]) * e[e_id] * mask_lrelu[e_id]; 345 | atomicAdd(&grad_e_pre[tar_id][1], grad_e_); 346 | atomicAdd(&grad_e_pre[src_index[e_id]][0], grad_e_); 347 | } 348 | 349 | } 350 | 351 | 352 | std::vector gat_gar_edge_weight_b_cuda( 353 | torch::Tensor grad_alpha_self, 354 | torch::Tensor grad_alpha, 355 | torch::Tensor src_index, 356 | torch::Tensor tar_index, 357 | torch::Tensor mask_lrelu, 358 | torch::Tensor mask_lrelu_self, 359 | torch::Tensor e, 360 | torch::Tensor e_self, 361 | torch::Tensor e_sum, 362 | torch::Tensor alpha_self, 363 | torch::Tensor alpha 364 | ){ 365 | unsigned int num_node = alpha_self.size(0); 366 
| unsigned int num_edge = alpha.size(0); 367 | auto grad_e_sum = torch::zeros_like(e_sum); 368 | auto grad_e = torch::empty_like(e); 369 | // auto grad_e_self = torch::empty_like(e_self); 370 | 371 | auto options = torch::TensorOptions().dtype(torch::kFloat32).device(alpha.device()); 372 | auto grad_e_pre = torch::empty({num_node, 2}, options); 373 | 374 | AT_DISPATCH_FLOATING_TYPES(grad_alpha.type(), "gat_gar_edge_weight_b_kernel", ([&]{ 375 | gat_gar_edge_weight_b_kernel<<>>( 376 | grad_alpha.data(), alpha.data(), 377 | tar_index.data(), e_sum.data(), 378 | grad_e_sum.data(), grad_e.data(), num_edge 379 | ); 380 | })); 381 | 382 | AT_DISPATCH_FLOATING_TYPES(grad_alpha.type(), "gat_gar_edge_weight_b_kernel", ([&]{ 383 | gat_gar_edge_weight_b2_kernel<<>>( 384 | grad_alpha_self.data(), alpha_self.data(), 385 | e_sum.data(), grad_e_sum.data(), 386 | e_self.data(), mask_lrelu_self.data(), 387 | grad_e_pre.packed_accessor(), num_node 388 | ); 389 | })); 390 | 391 | AT_DISPATCH_FLOATING_TYPES(grad_alpha.type(), "gat_gar_edge_weight_b_kernel", ([&]{ 392 | gat_gar_edge_weight_b3_kernel<<>>( 393 | grad_e.data(), e.data(), grad_e_sum.data(), 394 | mask_lrelu.data(), tar_index.data(), 395 | src_index.data(), 396 | grad_e_pre.packed_accessor(), 397 | num_edge 398 | ); 399 | })); 400 | 401 | return {grad_e_sum, grad_e_pre}; 402 | } 403 | -------------------------------------------------------------------------------- /fuseGNN/convs/gat_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import GATConv as geoGATConv 3 | from fuseGNN.functional import fused_gar_agg, fused_gas_agg, csr2csc, coo2csr, gat_gar_edge_weight, gat_gas_edge_weight 4 | from fuseGNN.functional.format import Coo2Csr 5 | from torch.nn import Parameter 6 | import torch.nn.functional as F 7 | import torch_scatter 8 | import math 9 | from fuseGNN.functional import dropout as my_dropout 10 | 11 | # behavior module with pure python 12 | 13 | def glorot(tensor): 14 | if tensor is not None: 15 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 16 | tensor.data.uniform_(-stdv, stdv) 17 | 18 | 19 | def zeros(tensor): 20 | if tensor is not None: 21 | tensor.data.fill_(0) 22 | 23 | 24 | ############################################################################################### 25 | # Reference Module 26 | ############################################################################################### 27 | class refGATConv(torch.nn.Module): 28 | def __init__(self, in_channels, out_channels, heads=1, concat=True, 29 | negative_slope=0.2, dropout=0, bias=True, flow='target_to_source', return_mask=False, cached=False): 30 | """ 31 | Args: 32 | in_channels (int): Size of each input sample 33 | out_channels (int): Size of each output sample. 34 | cached (bool, optional): if set to True, the layer will cache 35 | the computation of D^{-0.5}\hat{A}of D^{-0.5} on the first 36 | execution, and it will be used for further executions. 37 | This is only helpful in transductive learning 38 | bias (bool, optional): If set to True, there will be a bias 39 | flow (str): could be the following two conditions 40 | 'source_to_target': edge_index[0] is the source nodes, [1] is the target nodes 41 | 'target_to_source': edge_index[0] is the target nodes, [0] is the source nodes 42 | fused (bool, optional): If set to True, the gcnlib.gcn_aggregate_f_cuda will be used. 
Default: False, 43 | verify (bool, optional): If set to True, it will output the difference between fused and unfused version 44 | """ 45 | super(refGATConv, self).__init__() 46 | self.flow = flow 47 | assert self.flow in ['source_to_target', 'target_to_source'] 48 | self.tid, self.sid = (0, 1) if self.flow == 'target_to_source' else (1, 0) 49 | 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | self.heads = heads 53 | # The dense layer for Update stage. The total number of out_freatures is out_channels * heads 54 | self.dense = torch.nn.Linear(in_features=in_channels, out_features=out_channels * heads, bias=bias) 55 | 56 | # The attention parameters 57 | self.att = Parameter(torch.Tensor(heads * out_channels, 2)) 58 | 59 | glorot(self.dense.weight) 60 | glorot(self.att) 61 | # zeros(self.bias) 62 | 63 | self.negative_slope = negative_slope 64 | self.dropout = dropout 65 | self.concat = concat 66 | 67 | self.return_mask = return_mask 68 | 69 | 70 | def forward(self, x, edge_index, dp_mask=None, dp_mask_self=None): 71 | """ 72 | Args: 73 | x (float32 [N(v), in_channels]) : matrix of feature vectors 74 | edge_index (int64 [2, N(v)]) : the list of edges 75 | edge_weight (float32 [N(v)], optional) : the weight on each edge, default: 1 76 | """ 77 | # step 0: split the edge index and convert them to int32 78 | src_index, tar_index = (edge_index[self.sid], edge_index[self.tid]) 79 | 80 | # step 1: Update, forward input features into linear layer 81 | x = self.dense(x) 82 | 83 | # Aggregation stage 84 | return self.propagate(x, src_index, tar_index) 85 | 86 | def propagate(self, feature, src_index, tar_index): 87 | # step 1: get edge weight 88 | e_pre = torch.matmul(feature, self.att) 89 | 90 | e_pre_src, e_pre_tar = torch.chunk(e_pre, chunks=2, dim=1) 91 | e_pre_src_expand = torch.index_select(e_pre_src, dim=0, index=src_index) 92 | e_pre_tar_expand = torch.index_select(e_pre_tar, dim=0, index=tar_index) 93 | e = e_pre_src_expand + e_pre_tar_expand 94 | e_self = e_pre_src + e_pre_tar 95 | e = torch.exp(F.leaky_relu(e, negative_slope=self.negative_slope)) 96 | e_self = torch.exp(F.leaky_relu(e_self, negative_slope=self.negative_slope)) 97 | e_sum = torch_scatter.scatter_add(src=e, index=tar_index, dim=0, dim_size=feature.size(0)) 98 | e_sum += e_self 99 | alpha_self = e_self / e_sum 100 | e_sum = torch.index_select(e_sum, dim=0, index=tar_index) 101 | alpha = e / e_sum 102 | # apply dropout 103 | alpha, mask = my_dropout(alpha, self.dropout, self.training) 104 | alpha_self, mask_self = my_dropout(alpha_self, self.dropout, self.training) 105 | 106 | 107 | # step 1: scatter the feature vectors to the extended feature map of src 108 | src_fm = torch.index_select(feature, dim=0, index=src_index) 109 | 110 | # step 2: scatter the feature vectors to the extended feature map of tar 111 | tar_fm = torch.index_select(feature, dim=0, index=tar_index) 112 | 113 | extended_fm = src_fm * alpha 114 | 115 | # step 3: scatter to the output feature 116 | out = torch_scatter.scatter_add(src=extended_fm, index=tar_index, dim=0, dim_size=feature.size(0)) 117 | 118 | # step 4: apply self-loops 119 | out += feature * alpha_self 120 | 121 | if not self.concat: 122 | out = out.view(-1, self.heads, self.out_channels) 123 | out = out.mean(dim=1).squeeze_() 124 | 125 | if self.return_mask: 126 | return out, mask, mask_self 127 | else: 128 | return out 129 | 130 | 131 | class garGATConv(torch.nn.Module): 132 | def __init__(self, in_channels, out_channels, heads=1, concat=True, 133 | 
negative_slope=0.2, dropout=0, bias=True, flow='target_to_source', return_mask=False, 134 | cached=False): 135 | """ 136 | Args: 137 | in_channels (int): Size of each input sample 138 | out_channels (int): Size of each output sample. 139 | cached (bool, optional): if set to True, the layer will cache 140 | the computation of D^{-0.5}\hat{A}of D^{-0.5} on the first 141 | execution, and it will be used for further executions. 142 | This is only helpful in transductive learning 143 | bias (bool, optional): If set to True, there will be a bias 144 | flow (str): could be the following two conditions 145 | 'source_to_target': edge_index[0] is the source nodes, [1] is the target nodes 146 | 'target_to_source': edge_index[0] is the target nodes, [0] is the source nodes 147 | fused (bool, optional): If set to True, the gcnlib.gcn_aggregate_f_cuda will be used. Default: False, 148 | verify (bool, optional): If set to True, it will output the difference between fused and unfused version 149 | """ 150 | super(garGATConv, self).__init__() 151 | self.flow = flow 152 | assert self.flow in ['source_to_target', 'target_to_source'] 153 | self.tid, self.sid = (0, 1) if self.flow == 'target_to_source' else (1, 0) 154 | 155 | self.in_channels = in_channels 156 | self.out_channels = out_channels 157 | self.heads = heads 158 | # The dense layer for Update stage. The total number of out_freatures is out_channels * heads 159 | self.dense = torch.nn.Linear(in_features=in_channels, out_features=out_channels * heads, bias=bias) 160 | 161 | # The attention parameters 162 | self.att = Parameter(torch.Tensor(heads * out_channels, 2)) 163 | 164 | glorot(self.dense.weight) 165 | glorot(self.att) 166 | # zeros(self.bias) 167 | 168 | self.negative_slope = negative_slope 169 | self.dropout = dropout 170 | self.concat = concat 171 | 172 | self.cached = cached 173 | self.cached_tar_index_b = None 174 | self.cached_src_index_b = None 175 | self.cached_src_ptr = None 176 | 177 | self.return_mask = return_mask 178 | 179 | 180 | def forward(self, x, edge_index=None, dp_mask=None, dp_mask_self=None, tar_index_b=None, 181 | src_index_b=None, src_ptr=None): 182 | """ 183 | Args: 184 | x (float32 [N(v), in_channels]) : matrix of feature vectors 185 | edge_index (int64 [2, N(v)]) : the list of edges 186 | edge_weight (float32 [N(v)], optional) : the weight on each edge, default: 1 187 | """ 188 | # step 0: split the edge index and convert them to int32 189 | if (not self.cached) or (self.cached_tar_index_b is None): 190 | if src_ptr is None: 191 | src_index, tar_index = (edge_index[self.sid], edge_index[self.tid]) 192 | src_index = src_index.to(torch.int32) 193 | tar_index = tar_index.to(torch.int32) 194 | self.cached_tar_index_b, self.cached_src_index_b, self.cached_src_ptr, dp_mask = coo2csr(tar_index, src_index, 195 | x.size(0), dp_mask, False) 196 | else: 197 | self.cached_tar_index_b = tar_index_b 198 | self.cached_src_index_b = src_index_b 199 | self.cached_src_ptr = src_ptr 200 | 201 | # step 1: Update, forward input features into linear layer 202 | x = self.dense(x) 203 | 204 | # Aggregation stage 205 | return self.propagate(x, dp_mask, dp_mask_self) 206 | 207 | def propagate(self, feature, dp_mask, dp_mask_self): 208 | if dp_mask is not None: 209 | dp_mask = dp_mask.view(-1, 1) 210 | e_pre = torch.matmul(feature, self.att) 211 | 212 | alpha_self, alpha = gat_gar_edge_weight(e_pre, self.cached_src_ptr, 213 | self.cached_tar_index_b, self.cached_src_index_b, 214 | self.negative_slope) 215 | 216 | # dropout on edge weight. 
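        # When a precomputed mask is passed in (the GAT testbench reuses the dp_mask /
        # dp_mask_self returned by refGATConv(return_mask=True)), the same dropout
        # pattern is replayed here, so the fused output stays elementwise comparable
        # with the reference; otherwise an independent F.dropout draw is taken.
        # Illustrative call pattern (a sketch mirroring gat_conv_tb.py; ref_model /
        # fused_model are placeholder names, not part of this file's API):
        #   ref, dp_mask, dp_mask_self = ref_model(x)            # reference layer also returns its masks
        #   out = fused_model(x_f, dp_mask=dp_mask.detach(),     # fused layer replays the same masks
        #                     dp_mask_self=dp_mask_self.detach())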
217 | if dp_mask is not None and self.training: # If the mask is provided 218 | alpha = alpha * dp_mask 219 | alpha_self = alpha_self * dp_mask_self 220 | else: # Otherwise 221 | alpha = F.dropout(alpha, self.dropout, self.training) 222 | alpha_self = F.dropout(alpha_self, self.dropout, self.training) 223 | 224 | tar_ptr, src_index_f, alpha_f = csr2csc(self.cached_src_ptr.detach_(), self.cached_tar_index_b.detach_(), 225 | alpha.detach(), feature.size(0)) 226 | 227 | out = fused_gar_agg(feature=feature, src_index=src_index_f, tar_ptr=tar_ptr, 228 | edge_weight_f=alpha_f.detach_(), self_edge_weight=alpha_self, 229 | tar_index=self.cached_tar_index_b, src_ptr=self.cached_src_ptr, edge_weight_b=alpha, 230 | require_edge_weight=True) 231 | 232 | if not self.concat: 233 | out = out.view(-1, self.heads, self.out_channels) 234 | out = out.mean(dim=1).squeeze_() 235 | 236 | 237 | if self.return_mask: 238 | return out, dp_mask, dp_mask_self 239 | else: 240 | return out 241 | 242 | 243 | class gasGATConv(torch.nn.Module): 244 | def __init__(self, in_channels, out_channels, heads=1, concat=True, 245 | negative_slope=0.2, dropout=0, bias=True, flow='target_to_source', return_mask=False, 246 | cached=False): 247 | """ 248 | Args: 249 | in_channels (int): Size of each input sample 250 | out_channels (int): Size of each output sample. 251 | cached (bool, optional): if set to True, the layer will cache 252 | the computation of D^{-0.5}\hat{A}of D^{-0.5} on the first 253 | execution, and it will be used for further executions. 254 | This is only helpful in transductive learning 255 | bias (bool, optional): If set to True, there will be a bias 256 | flow (str): could be the following two conditions 257 | 'source_to_target': edge_index[0] is the source nodes, [1] is the target nodes 258 | 'target_to_source': edge_index[0] is the target nodes, [0] is the source nodes 259 | fused (bool, optional): If set to True, the gcnlib.gcn_aggregate_f_cuda will be used. Default: False, 260 | verify (bool, optional): If set to True, it will output the difference between fused and unfused version 261 | """ 262 | super(gasGATConv, self).__init__() 263 | self.flow = flow 264 | assert self.flow in ['source_to_target', 'target_to_source'] 265 | self.tid, self.sid = (0, 1) if self.flow == 'target_to_source' else (1, 0) 266 | 267 | self.in_channels = in_channels 268 | self.out_channels = out_channels 269 | self.heads = heads 270 | # The dense layer for Update stage. 
The total number of out_freatures is out_channels * heads 271 | self.dense = torch.nn.Linear(in_features=in_channels, out_features=out_channels * heads, bias=bias) 272 | 273 | # The attention parameters 274 | self.att = Parameter(torch.Tensor(heads * out_channels, 2)) 275 | 276 | glorot(self.dense.weight) 277 | glorot(self.att) 278 | # zeros(self.bias) 279 | 280 | self.negative_slope = negative_slope 281 | self.dropout = dropout 282 | self.concat = concat 283 | self.cached = cached 284 | self.cached_src_index = None 285 | self.cached_tar_index = None 286 | 287 | self.return_mask = return_mask 288 | 289 | 290 | def forward(self, x, edge_index=None, dp_mask=None, dp_mask_self=None, src_index=None, tar_index=None): 291 | """ 292 | Args: 293 | x (float32 [N(v), in_channels]) : matrix of feature vectors 294 | edge_index (int64 [2, N(v)]) : the list of edges 295 | edge_weight (float32 [N(v)], optional) : the weight on each edge, default: 1 296 | """ 297 | # step 0: split the edge index and convert them to int32 298 | if (not self.cached) or (self.cached_src_index is None): 299 | if src_index is None: 300 | src_index, tar_index = (edge_index[self.sid], edge_index[self.tid]) 301 | src_index = src_index.to(torch.int32) 302 | tar_index = tar_index.to(torch.int32) 303 | self.cached_src_index = src_index 304 | self.cached_tar_index = tar_index 305 | 306 | # step 1: Update, forward input features into linear layer 307 | x = self.dense(x) 308 | 309 | # Aggregation stage 310 | return self.propagate(x, dp_mask, dp_mask_self) 311 | 312 | def propagate(self, feature, dp_mask, dp_mask_self): 313 | 314 | e_pre = torch.matmul(feature, self.att) 315 | 316 | alpha_self, alpha = gat_gas_edge_weight(e_pre, self.cached_src_index, self.cached_tar_index, self.negative_slope) 317 | 318 | if dp_mask is not None and self.training: # If the mask is provided 319 | alpha = alpha * dp_mask 320 | alpha_self = alpha_self * dp_mask_self 321 | else: # Otherwise 322 | alpha = F.dropout(alpha, self.dropout, self.training) 323 | alpha_self = F.dropout(alpha_self, self.dropout, self.training) 324 | 325 | out = fused_gas_agg(feature=feature, src_index=self.cached_src_index, tar_index=self.cached_tar_index, 326 | edge_weight=alpha, self_edge_weight=alpha_self, 327 | require_edge_weight=True) 328 | 329 | if not self.concat: 330 | out = out.view(-1, self.heads, self.out_channels) 331 | out = out.mean(dim=1).squeeze_() 332 | 333 | 334 | if self.return_mask: 335 | return out, dp_mask, dp_mask_self 336 | else: 337 | return out 338 | -------------------------------------------------------------------------------- /src/cuda/aggregate_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | 8 | template 9 | __device__ void smem_reduce_v1(volatile scalar_t* sdata, unsigned int tid, unsigned int reduce_len, unsigned int f_dim){ 10 | while (reduce_len > 1){ 11 | __syncthreads(); 12 | // add the remainer 13 | if ((tid < f_dim) && (reduce_len % 2 == 1)){ 14 | sdata[tid] += sdata[tid + f_dim * (reduce_len - 1)]; 15 | } 16 | reduce_len /= 2; 17 | if (tid < f_dim * reduce_len){ 18 | sdata[tid] += sdata[tid + f_dim * reduce_len]; 19 | } 20 | } 21 | } 22 | 23 | 24 | template 25 | __device__ void smem_reduce_v2(volatile scalar_t* sdata, unsigned int tid, unsigned int reduce_len){ 26 | if (reduce_len > 512){ 27 | if(tid < 512){ 28 | sdata[tid] += sdata[tid + 512]; 29 | } 30 | __syncthreads(); 31 | } 32 | if (reduce_len > 256){ 33 | if(tid < 256){ 34 | 
sdata[tid] += sdata[tid + 256]; 35 | } 36 | __syncthreads(); 37 | } 38 | if (reduce_len > 128){ 39 | if(tid < 128){ 40 | sdata[tid] += sdata[tid + 128]; 41 | } 42 | __syncthreads(); 43 | } 44 | if (reduce_len > 64){ 45 | if(tid < 64){ 46 | sdata[tid] += sdata[tid + 64]; 47 | } 48 | __syncthreads(); 49 | } 50 | if (tid < 32){ 51 | if (reduce_len > 32) sdata[tid] += sdata[tid + 32]; 52 | if (reduce_len > 16) sdata[tid] += sdata[tid + 16]; 53 | if (reduce_len > 8) sdata[tid] += sdata[tid + 8]; 54 | if (reduce_len > 4) sdata[tid] += sdata[tid + 4]; 55 | if (reduce_len > 2) sdata[tid] += sdata[tid + 2]; 56 | if (reduce_len > 1) sdata[tid] += sdata[tid + 1]; 57 | } 58 | } 59 | 60 | 61 | template 62 | __device__ void smem_reduce_v3(volatile scalar_t* sdata, unsigned int tid){ 63 | if (blockSize >= 1024){ 64 | if (tid < 512){ 65 | sdata[tid] += sdata[tid + 512]; 66 | } 67 | __syncthreads(); 68 | } 69 | if (blockSize >= 512){ 70 | if (tid < 256){ 71 | sdata[tid] += sdata[tid + 256]; 72 | } 73 | __syncthreads(); 74 | } 75 | if (blockSize >= 256){ 76 | if (tid < 128){ 77 | sdata[tid] += sdata[tid + 128]; 78 | } 79 | __syncthreads(); 80 | } 81 | if (blockSize >= 128){ 82 | if (tid < 64){ 83 | sdata[tid] += sdata[tid + 64]; 84 | } 85 | __syncthreads(); 86 | } 87 | if (tid < 32){ 88 | if (blockSize >= 64) sdata[tid] += sdata[tid + 32]; 89 | if (blockSize >= 32) sdata[tid] += sdata[tid + 16]; 90 | if (blockSize >= 16) sdata[tid] += sdata[tid + 8]; 91 | if (blockSize >= 8) sdata[tid] += sdata[tid + 4]; 92 | if (blockSize >= 4) sdata[tid] += sdata[tid + 2]; 93 | if (blockSize >= 2) sdata[tid] += sdata[tid + 1]; 94 | } 95 | } 96 | 97 | 98 | /* 99 | * The fused forward kernel for GAR aggregator 100 | */ 101 | 102 | 103 | template 104 | __global__ void fused_gar_f_kernel( 105 | const torch::PackedTensorAccessor __restrict__ feature, 106 | torch::PackedTensorAccessor out, 107 | const int* __restrict__ src_index, 108 | const int* __restrict__ tar_ptr, 109 | const scalar_t* __restrict__ edge_weight, 110 | const scalar_t* __restrict__ self_edge_weight, 111 | const unsigned int f_dim, 112 | const unsigned int stride 113 | ){ 114 | // shared memory for feature reduction 115 | __shared__ scalar_t s_feature[blockSize]; 116 | // Registers 117 | unsigned int tid = threadIdx.x; 118 | unsigned int tar_id = blockIdx.x; // each block handles a single target 119 | unsigned int f_idx = tid % f_dim; 120 | unsigned int e_start = tar_ptr[tar_id]; 121 | unsigned int e_bound = tar_ptr[tar_id + 1]; 122 | 123 | // Step 0: initialize shared memory 124 | s_feature[tid] = 0; 125 | // Step 1: reduce the feature vectors into shared memory 126 | for (unsigned int e_id = e_start + tid / f_dim; e_id < e_bound; e_id += stride){ 127 | s_feature[tid] += feature[src_index[e_id]][f_idx] * edge_weight[e_id]; 128 | } 129 | 130 | // Step 2: Reduction 131 | unsigned int reduce_len = min(stride, e_bound - e_start); 132 | 133 | smem_reduce_v1(s_feature, tid, reduce_len, f_dim); 134 | 135 | // Step 3: Write out 136 | if (tid < f_dim){ 137 | out[tar_id][f_idx] = feature[tar_id][f_idx] * self_edge_weight[tar_id] + s_feature[tid]; 138 | } 139 | } 140 | 141 | 142 | template 143 | __global__ void fused_gar_f_large_kernel( 144 | const torch::PackedTensorAccessor feature, 145 | torch::PackedTensorAccessor out, 146 | const int* __restrict__ src_index, 147 | const int* __restrict__ tar_ptr, 148 | const scalar_t* __restrict__ edge_weight, 149 | const scalar_t* __restrict__ self_edge_weight, 150 | const unsigned int f_dim 151 | ){ 152 | unsigned int tid = 
threadIdx.x; 153 | unsigned int tar_id = blockIdx.x; 154 | scalar_t self_weight = self_edge_weight[tar_id]; 155 | 156 | for (unsigned int f_idx = tid; f_idx < f_dim; f_idx += blockSize){ 157 | scalar_t s_feature = feature[tar_id][f_idx] * self_weight; 158 | for (unsigned int e_id=tar_ptr[tar_id]; e_id < tar_ptr[tar_id + 1]; e_id ++){ 159 | s_feature += feature[src_index[e_id]][f_idx] * edge_weight[e_id]; 160 | } 161 | out[tar_id][f_idx] = s_feature; 162 | } 163 | } 164 | 165 | #define GAR_THREADS 128 166 | 167 | // CUDA GAR model forward 168 | torch::Tensor fused_gar_f_cuda( 169 | torch::Tensor feature, 170 | torch::Tensor src_index, 171 | torch::Tensor tar_ptr, 172 | torch::Tensor edge_weight, 173 | torch::Tensor self_edge_weight 174 | ){ 175 | unsigned int f_dim = feature.size(1); 176 | unsigned int num_node = feature.size(0); 177 | 178 | auto out = torch::empty_like(feature); 179 | 180 | if (f_dim <= GAR_THREADS){ 181 | unsigned int stride = GAR_THREADS / f_dim; 182 | AT_DISPATCH_FLOATING_TYPES(feature.type(), "aggregation gar forward", ([&]{ 183 | fused_gar_f_kernel<<>>( 184 | feature.packed_accessor(), 185 | out.packed_accessor(), 186 | src_index.data(), 187 | tar_ptr.data(), 188 | edge_weight.data(), 189 | self_edge_weight.data(), f_dim, stride 190 | ); 191 | })); 192 | }else{ 193 | AT_DISPATCH_FLOATING_TYPES(feature.type(), "aggregation gar forward", ([&]{ 194 | fused_gar_f_large_kernel<<>>( 195 | feature.packed_accessor(), 196 | out.packed_accessor(), 197 | src_index.data(), 198 | tar_ptr.data(), 199 | edge_weight.data(), 200 | self_edge_weight.data(), f_dim 201 | ); 202 | })); 203 | } 204 | return out; 205 | } 206 | 207 | 208 | /* 209 | * The fused backward kernel for GAR aggregator 210 | */ 211 | 212 | template 213 | __global__ void fused_gar_b_kernel( 214 | const torch::PackedTensorAccessor __restrict__ grad_out, 215 | torch::PackedTensorAccessor grad_feature, 216 | const int* __restrict__ tar_index, 217 | const int* __restrict__ src_ptr, 218 | const scalar_t* __restrict__ edge_weight, 219 | const scalar_t* __restrict__ self_edge_weight, 220 | unsigned int f_dim, unsigned int stride 221 | ){ 222 | // shared memory for gradient reduction 223 | __shared__ scalar_t s_grad_feature[blockSize]; 224 | unsigned int tid = threadIdx.x; 225 | unsigned int src_id = blockIdx.x; 226 | unsigned int f_idx = tid % f_dim; 227 | unsigned int e_start = src_ptr[src_id]; 228 | unsigned int e_bound = src_ptr[src_id+1]; 229 | 230 | // initialize the shared memory 231 | s_grad_feature[tid] = 0; 232 | 233 | for (unsigned int e_id = e_start + tid / f_dim; e_id < e_bound; e_id += stride){ 234 | s_grad_feature[tid] += grad_out[tar_index[e_id]][f_idx] * edge_weight[e_id]; 235 | } 236 | 237 | // Step 2: reduction 238 | unsigned int reduce_len = min(stride, e_bound - e_start); 239 | 240 | smem_reduce_v1(s_grad_feature, tid, reduce_len, f_dim); 241 | 242 | if (tid < f_dim){ 243 | grad_feature[src_id][f_idx] = grad_out[src_id][f_idx] * self_edge_weight[src_id] + s_grad_feature[tid]; 244 | } 245 | } 246 | 247 | 248 | 249 | template 250 | __global__ void fused_gar_b_kernelv2( 251 | const torch::PackedTensorAccessor __restrict__ grad_out, 252 | const torch::PackedTensorAccessor feature, 253 | torch::PackedTensorAccessor grad_feature, 254 | const int* __restrict__ tar_index, 255 | const int* __restrict__ src_ptr, 256 | scalar_t* grad_edge_weight, 257 | scalar_t* grad_self_edge_weight, 258 | const scalar_t* __restrict__ edge_weight, 259 | const scalar_t* __restrict__ self_edge_weight, 260 | unsigned int f_dim, 
unsigned int stride 261 | ){ 262 | // shared memory for gradient reduction 263 | __shared__ scalar_t s_grad_feature[blockSize]; 264 | __shared__ scalar_t s_grad_edge_weight[blockSize]; 265 | 266 | unsigned int tid = threadIdx.x; 267 | unsigned int src_id = blockIdx.x; 268 | unsigned int f_idx = tid % f_dim; 269 | unsigned int e_start = src_ptr[src_id]; 270 | unsigned int e_bound = src_ptr[src_id+1]; 271 | unsigned int group_id = tid / f_dim; 272 | unsigned int e_idx = e_start + group_id; 273 | 274 | unsigned int reduce_len = f_dim; 275 | 276 | // initialize the shared memory 277 | s_grad_feature[tid] = 0; 278 | s_grad_edge_weight[tid] = 0; 279 | 280 | scalar_t grad_out_buffer = 0; 281 | scalar_t src_feature_buffer = feature[src_id][f_idx]; 282 | 283 | unsigned int total_strides = (e_bound - e_start + stride - 1) / stride; 284 | 285 | for (unsigned int i=e_start; i < total_strides * stride + e_start; i += stride){ 286 | __syncthreads(); 287 | if ((tid < stride * f_dim) && (e_idx < e_bound)){ 288 | grad_out_buffer = grad_out[tar_index[e_idx]][f_idx]; 289 | s_grad_feature[tid] += grad_out_buffer * edge_weight[e_idx]; 290 | // interleavely load the intermediate results into s_grad_edge_weight 291 | s_grad_edge_weight[group_id + f_idx * stride] = grad_out_buffer * src_feature_buffer; 292 | } 293 | __syncthreads(); 294 | reduce_len = f_dim; 295 | smem_reduce_v1(s_grad_edge_weight, tid, reduce_len, stride); 296 | __syncthreads(); 297 | if ((tid < stride) && (i + tid < e_bound)){ 298 | grad_edge_weight[i+tid] = s_grad_edge_weight[tid]; 299 | } 300 | s_grad_edge_weight[tid] = 0; 301 | e_idx += stride; 302 | } 303 | 304 | // Step 2: reduction 305 | reduce_len = min(stride, e_bound - e_start); 306 | 307 | smem_reduce_v1(s_grad_feature, tid, reduce_len, f_dim); 308 | 309 | if (tid < f_dim){ 310 | grad_out_buffer = grad_out[src_id][f_idx]; 311 | s_grad_edge_weight[tid] = grad_out_buffer * src_feature_buffer; 312 | grad_feature[src_id][f_idx] = grad_out_buffer * self_edge_weight[src_id] + s_grad_feature[tid]; 313 | } 314 | __syncthreads(); 315 | 316 | smem_reduce_v2(s_grad_edge_weight, tid, f_dim); 317 | if (tid == 0){ 318 | grad_self_edge_weight[src_id] = s_grad_edge_weight[0]; 319 | } 320 | 321 | } 322 | 323 | 324 | template 325 | __global__ void fused_gar_b_large_kernel( 326 | const torch::PackedTensorAccessor __restrict__ grad_out, 327 | torch::PackedTensorAccessor grad_feature, 328 | const int* __restrict__ tar_index, 329 | const int* __restrict__ src_ptr, 330 | const scalar_t* __restrict__ edge_weight, 331 | const scalar_t* __restrict__ self_edge_weight, 332 | unsigned int f_dim 333 | ){ 334 | // shared memory for feature reduction 335 | // __shared__ scalar_t s_grad_feature[blockSize]; 336 | unsigned int tid = threadIdx.x; 337 | unsigned int src_id = blockIdx.x; 338 | scalar_t self_weight = self_edge_weight[src_id]; 339 | 340 | for (unsigned int f_idx = tid; f_idx < f_dim; f_idx += blockSize){ 341 | scalar_t s_grad_feature = grad_out[src_id][f_idx] * self_weight; 342 | for (unsigned int e_id=src_ptr[src_id]; e_id < src_ptr[src_id + 1]; e_id ++){ 343 | s_grad_feature += grad_out[tar_index[e_id]][f_idx] * edge_weight[e_id]; 344 | } 345 | grad_feature[src_id][f_idx] = s_grad_feature; 346 | } 347 | } 348 | 349 | 350 | template 351 | __global__ void fused_gar_b_large_kernelv2( 352 | const torch::PackedTensorAccessor __restrict__ grad_out, 353 | const torch::PackedTensorAccessor __restrict__ feature, 354 | torch::PackedTensorAccessor grad_feature, 355 | const int* __restrict__ tar_index, 356 | const 
int* __restrict__ src_ptr, 357 | scalar_t* grad_edge_weight, 358 | scalar_t* grad_self_edge_weight, 359 | const scalar_t* __restrict__ edge_weight, 360 | const scalar_t* __restrict__ self_edge_weight, 361 | unsigned int f_dim 362 | ){ 363 | // shared memory to buffer the src feature 364 | __shared__ scalar_t s_src_feature[1024]; 365 | __shared__ scalar_t s_grad_feature[1024]; 366 | // shared memory for edge weight reduction 367 | __shared__ scalar_t s_grad_edge_weight[blockSize]; 368 | 369 | unsigned int tid = threadIdx.x; 370 | unsigned int src_id = blockIdx.x; 371 | scalar_t self_weight = self_edge_weight[src_id]; 372 | scalar_t grad_out_buffer = 0; 373 | s_grad_edge_weight[tid] = 0; 374 | 375 | for (unsigned int f_idx=tid; f_idx < f_dim; f_idx += blockSize){ 376 | grad_out_buffer = grad_out[src_id][f_idx]; 377 | s_src_feature[f_idx] = feature[src_id][f_idx]; 378 | s_grad_feature[f_idx] = grad_out_buffer * self_weight; 379 | s_grad_edge_weight[tid] += s_src_feature[f_idx] * grad_out_buffer; 380 | } 381 | __syncthreads(); 382 | smem_reduce_v3(s_grad_edge_weight, tid); 383 | if (tid == 0){ 384 | grad_self_edge_weight[src_id] = s_grad_edge_weight[0]; 385 | } 386 | __syncthreads(); 387 | 388 | 389 | for (unsigned int e_id=src_ptr[src_id]; e_id < src_ptr[src_id + 1]; e_id ++){ 390 | // for each edge 391 | s_grad_edge_weight[tid] = 0; 392 | scalar_t weight = edge_weight[e_id]; 393 | unsigned int tar_id = tar_index[e_id]; 394 | for (unsigned int f_idx=tid; f_idx < f_dim; f_idx += blockSize){ 395 | grad_out_buffer = grad_out[tar_id][f_idx]; 396 | s_grad_feature[f_idx] += grad_out_buffer * weight; 397 | s_grad_edge_weight[tid] += grad_out_buffer * s_src_feature[f_idx]; 398 | } 399 | __syncthreads(); 400 | smem_reduce_v3(s_grad_edge_weight, tid); 401 | __syncthreads(); 402 | if (tid == 0){ 403 | grad_edge_weight[e_id] = s_grad_edge_weight[0]; 404 | } 405 | } 406 | 407 | for (unsigned int f_idx = tid; f_idx < f_dim; f_idx += blockSize){ 408 | grad_feature[src_id][f_idx] = s_grad_feature[f_idx]; 409 | } 410 | } 411 | 412 | 413 | std::vector fused_gar_b_cuda( 414 | torch::Tensor grad_out, 415 | torch::Tensor feature, 416 | torch::Tensor tar_index, 417 | torch::Tensor src_ptr, 418 | torch::Tensor edge_weight, 419 | torch::Tensor self_edge_weight, 420 | bool require_edge_weight 421 | ){ 422 | unsigned int f_dim = grad_out.size(1); 423 | unsigned int num_node = grad_out.size(0); 424 | 425 | auto grad_feature = torch::empty_like(grad_out); 426 | auto grad_edge_weight = torch::empty_like(edge_weight); 427 | auto grad_self_edge_weight = torch::empty_like(self_edge_weight); 428 | 429 | if (f_dim <= GAR_THREADS){ 430 | unsigned int stride = GAR_THREADS / f_dim; 431 | if (require_edge_weight){ 432 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "fused_gar_b_kernel", ([&]{ 433 | fused_gar_b_kernelv2<<>>( 434 | grad_out.packed_accessor(), 435 | feature.packed_accessor(), 436 | grad_feature.packed_accessor(), 437 | tar_index.data(), src_ptr.data(), grad_edge_weight.data(), 438 | grad_self_edge_weight.data(), edge_weight.data(), 439 | self_edge_weight.data(), f_dim, stride 440 | ); 441 | })); 442 | }else{ 443 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "fused_gar_b_larger_kernel", ([&]{ 444 | fused_gar_b_kernel<<>>( 445 | grad_out.packed_accessor(), 446 | grad_feature.packed_accessor(), 447 | tar_index.data(), src_ptr.data(), edge_weight.data(), 448 | self_edge_weight.data(), f_dim, stride 449 | ); 450 | })); 451 | } 452 | }else{ 453 | if (require_edge_weight){ 454 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), 
"aggregation gar forward", ([&]{ 455 | fused_gar_b_large_kernelv2<<>>( 456 | grad_out.packed_accessor(), 457 | feature.packed_accessor(), 458 | grad_feature.packed_accessor(), 459 | tar_index.data(), src_ptr.data(), grad_edge_weight.data(), 460 | grad_self_edge_weight.data(), edge_weight.data(), 461 | self_edge_weight.data(), f_dim 462 | ); 463 | })); 464 | }else{ 465 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "aggregation gar forward", ([&]{ 466 | fused_gar_b_large_kernel<<>>( 467 | grad_out.packed_accessor(), 468 | grad_feature.packed_accessor(), 469 | tar_index.data(), src_ptr.data(), edge_weight.data(), 470 | self_edge_weight.data(), f_dim 471 | ); 472 | })); 473 | } 474 | } 475 | return {grad_feature, grad_edge_weight, grad_self_edge_weight}; 476 | } 477 | 478 | /***********************************************************************************************************/ 479 | 480 | /* 481 | * The fused forward kernel for GAS aggregator 482 | */ 483 | 484 | template 485 | __global__ void fused_gas_f_kernel( 486 | const torch::PackedTensorAccessor feature, 487 | torch::PackedTensorAccessor out, 488 | const int* __restrict__ src_index, 489 | const int* __restrict__ tar_index, 490 | const scalar_t* __restrict__ edge_weight, 491 | const unsigned int edge_stride, // each thread block handles how many edges 492 | const unsigned int f_dim, const unsigned int num_edge 493 | ){ 494 | unsigned int tid = threadIdx.x; 495 | unsigned int bid = blockIdx.x; 496 | if (tid < f_dim * edge_stride){ 497 | unsigned int f_idx = tid % f_dim; 498 | for (unsigned int e_id = bid * edge_stride + tid / f_dim; e_id < num_edge; e_id += gridDim.x * edge_stride){ 499 | atomicAdd(&out[tar_index[e_id]][f_idx], feature[src_index[e_id]][f_idx] * edge_weight[e_id]); 500 | } 501 | } 502 | } 503 | 504 | template 505 | __global__ void fused_gas_f_large_kernel( 506 | const torch::PackedTensorAccessor feature, 507 | torch::PackedTensorAccessor out, 508 | const int* __restrict__ src_index, 509 | const int* __restrict__ tar_index, 510 | const scalar_t* __restrict__ edge_weight, 511 | const unsigned int f_dim, const unsigned int num_edge 512 | ){ 513 | unsigned int e_id = blockIdx.x; 514 | scalar_t weight = edge_weight[e_id]; 515 | for (unsigned int f_idx = threadIdx.x; f_idx < f_dim; f_idx += blockDim.x){ 516 | atomicAdd(&out[tar_index[e_id]][f_idx], feature[src_index[e_id]][f_idx] * weight); 517 | } 518 | } 519 | 520 | template 521 | __global__ void scaled_clone( 522 | const scalar_t* __restrict__ input, 523 | scalar_t* __restrict__ output, 524 | scalar_t* __restrict__ self_edge_weight, 525 | unsigned int f_dim, 526 | unsigned int numel 527 | ){ 528 | unsigned int stride = blockDim.x * gridDim.x; 529 | for (unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel; tid += stride){ 530 | unsigned int node = tid / f_dim; 531 | output[tid] = input[tid] * self_edge_weight[node]; 532 | } 533 | } 534 | 535 | #define BLOCKS(N, T) (N + T - 1)/T 536 | #define GAS_THREADS 256 537 | 538 | torch::Tensor fused_gas_f_cuda( 539 | torch::Tensor feature, 540 | torch::Tensor src_index, 541 | torch::Tensor tar_index, 542 | torch::Tensor edge_weight, 543 | torch::Tensor self_edge_weight 544 | ){ 545 | unsigned int f_dim = feature.size(1); 546 | unsigned int num_edge = edge_weight.size(0); 547 | 548 | auto out = torch::empty_like(feature); 549 | // initialize output with self loop 550 | AT_DISPATCH_FLOATING_TYPES(feature.type(), "scaled clone", ([&]{ 551 | scaled_clone<<>>( 552 | feature.data(), out.data(), self_edge_weight.data(), 
553 | f_dim, feature.numel() 554 | ); 555 | })); 556 | 557 | if (f_dim <= GAS_THREADS){ 558 | unsigned int stride = GAS_THREADS/ f_dim; 559 | AT_DISPATCH_FLOATING_TYPES(feature.type(), "fused_gas_f_kernel", ([&]{ 560 | fused_gas_f_kernel<<>>( 561 | feature.packed_accessor(), 562 | out.packed_accessor(), 563 | src_index.data(), tar_index.data(), 564 | edge_weight.data(), stride, f_dim, num_edge 565 | ); 566 | })); 567 | }else{ 568 | AT_DISPATCH_FLOATING_TYPES(feature.type(), "fused_gas_f_kernel_g", ([&]{ 569 | fused_gas_f_large_kernel<<>>( 570 | feature.packed_accessor(), 571 | out.packed_accessor(), 572 | src_index.data(), tar_index.data(), 573 | edge_weight.data(), f_dim, num_edge 574 | ); 575 | })); 576 | } 577 | return out; 578 | } 579 | 580 | 581 | /* 582 | * The fused backward kernel for GAS aggregator 583 | */ 584 | 585 | template 586 | __global__ void fused_gas_b_kernel( 587 | const torch::PackedTensorAccessor grad_out, 588 | torch::PackedTensorAccessor grad_feature, 589 | const int* __restrict__ src_index, 590 | const int* __restrict__ tar_index, 591 | const scalar_t* __restrict__ edge_weight, 592 | const unsigned int edge_stride, // each thread block handles how many edges 593 | const unsigned int f_dim, const unsigned int num_edge 594 | ){ 595 | unsigned int tid = threadIdx.x; 596 | unsigned int bid = blockIdx.x; 597 | 598 | if (tid < f_dim * edge_stride){ 599 | unsigned int f_idx = tid % f_dim; // which entry to handle 600 | for (unsigned int e_id = bid* edge_stride + tid / f_dim; e_id < num_edge; e_id += gridDim.x * edge_stride){ 601 | atomicAdd(&grad_feature[src_index[e_id]][f_idx], grad_out[tar_index[e_id]][f_idx] * edge_weight[e_id]); 602 | } 603 | } 604 | } 605 | 606 | 607 | 608 | template 609 | __global__ void fused_gas_b_kernelv2( 610 | const torch::PackedTensorAccessor grad_out, 611 | const torch::PackedTensorAccessor feature, 612 | torch::PackedTensorAccessor grad_feature, 613 | const int* __restrict__ src_index, 614 | const int* __restrict__ tar_index, 615 | const scalar_t* __restrict__ edge_weight, 616 | scalar_t* grad_edge_weight, 617 | const unsigned int edge_stride, // each thread block handles how many edges 618 | const unsigned int f_dim, const unsigned int num_edge 619 | ){ 620 | __shared__ scalar_t s_grad_edge_weight[blockSize]; 621 | unsigned int tid = threadIdx.x; 622 | unsigned int bid = blockIdx.x; 623 | unsigned int group_id = tid / f_dim; 624 | scalar_t grad_out_buffer = 0; 625 | unsigned int reduce_len = f_dim; 626 | 627 | unsigned int e_start = bid * edge_stride; 628 | unsigned int e_bound = num_edge; 629 | unsigned int f_idx = tid % f_dim; 630 | unsigned int e_idx = e_start + group_id; 631 | unsigned stride = edge_stride * gridDim.x; 632 | 633 | unsigned int total_strides = (e_bound - e_start + stride- 1) / stride; 634 | s_grad_edge_weight[tid] = 0; 635 | 636 | for (unsigned int i=e_start; i < total_strides * stride + e_start; i += stride){ 637 | __syncthreads(); 638 | if ((tid < edge_stride * f_dim) && (e_idx < e_bound)){ 639 | grad_out_buffer = grad_out[tar_index[e_idx]][f_idx]; 640 | atomicAdd(&grad_feature[src_index[e_idx]][f_idx], grad_out_buffer * edge_weight[e_idx]); 641 | s_grad_edge_weight[group_id + f_idx * edge_stride] = grad_out_buffer * feature[src_index[e_idx]][f_idx]; 642 | } 643 | __syncthreads(); 644 | reduce_len = f_dim; 645 | smem_reduce_v1(s_grad_edge_weight, tid, reduce_len, edge_stride); 646 | __syncthreads(); 647 | if ((tid < edge_stride) && (i + tid < e_bound)){ 648 | grad_edge_weight[i + tid] = s_grad_edge_weight[tid]; 649 | } 
650 | s_grad_edge_weight[tid] = 0; 651 | e_idx += stride; 652 | } 653 | } 654 | 655 | 656 | template 657 | __global__ void fused_gas_b_large_kernel( 658 | const torch::PackedTensorAccessor grad_out, 659 | torch::PackedTensorAccessor grad_feature, 660 | const int* __restrict__ src_index, 661 | const int* __restrict__ tar_index, 662 | const scalar_t* __restrict__ edge_weight, 663 | const unsigned int f_dim, const unsigned int num_edge 664 | ){ 665 | unsigned int e_id = blockIdx.x; 666 | scalar_t weight = edge_weight[e_id]; 667 | unsigned int tid = threadIdx.x; 668 | 669 | for (unsigned int f_idx = tid; f_idx < f_dim; f_idx += blockDim.x){ 670 | atomicAdd(&grad_feature[src_index[e_id]][f_idx], grad_out[tar_index[e_id]][f_idx] * weight); 671 | } 672 | } 673 | 674 | 675 | template 676 | __global__ void fused_gas_b_large_kernelv2( 677 | const torch::PackedTensorAccessor grad_out, 678 | const torch::PackedTensorAccessor feature, 679 | torch::PackedTensorAccessor grad_feature, 680 | const int* __restrict__ src_index, 681 | const int* __restrict__ tar_index, 682 | const scalar_t* __restrict__ edge_weight, 683 | scalar_t* grad_edge_weight, 684 | const unsigned int f_dim, const unsigned int num_edge 685 | ){ 686 | __shared__ scalar_t s_grad_edge_weight[blockSize]; 687 | 688 | unsigned int e_id = blockIdx.x; 689 | scalar_t weight = edge_weight[e_id]; 690 | unsigned int tid = threadIdx.x; 691 | scalar_t grad_out_buffer = 0; 692 | 693 | s_grad_edge_weight[tid] = 0; 694 | 695 | for (unsigned int f_idx = tid; f_idx < f_dim; f_idx += blockDim.x){ 696 | grad_out_buffer = grad_out[tar_index[e_id]][f_idx]; 697 | s_grad_edge_weight[tid] += grad_out_buffer * feature[src_index[e_id]][f_idx]; 698 | atomicAdd(&grad_feature[src_index[e_id]][f_idx], grad_out_buffer * weight); 699 | } 700 | __syncthreads(); 701 | smem_reduce_v3(s_grad_edge_weight, tid); 702 | if (tid == 0){ 703 | grad_edge_weight[e_id] = s_grad_edge_weight[0]; 704 | } 705 | } 706 | 707 | 708 | template 709 | __global__ void grad_self_loop( 710 | const torch::PackedTensorAccessor grad_out, 711 | const torch::PackedTensorAccessor feature, 712 | torch::PackedTensorAccessor grad_feature, 713 | const scalar_t* __restrict__ self_edge_weight, 714 | scalar_t* grad_self_edge_weight, 715 | const unsigned int f_dim 716 | ){ 717 | unsigned int src_id = blockIdx.x; 718 | unsigned int tid = threadIdx.x; 719 | 720 | __shared__ scalar_t s_grad_edge_weight[blockSize]; 721 | s_grad_edge_weight[tid] = 0; 722 | 723 | if (tid < f_dim){ 724 | scalar_t grad_out_buffer = grad_out[src_id][tid]; 725 | s_grad_edge_weight[tid] = grad_out_buffer * feature[src_id][tid]; 726 | grad_feature[src_id][tid] = grad_out_buffer * self_edge_weight[src_id]; 727 | } 728 | __syncthreads(); 729 | 730 | smem_reduce_v2(s_grad_edge_weight, tid, f_dim); 731 | if (tid == 0){ 732 | grad_self_edge_weight[src_id] = s_grad_edge_weight[0]; 733 | } 734 | } 735 | 736 | 737 | template 738 | __global__ void grad_self_loop_large( 739 | const torch::PackedTensorAccessor grad_out, 740 | const torch::PackedTensorAccessor feature, 741 | torch::PackedTensorAccessor grad_feature, 742 | const scalar_t* __restrict__ self_edge_weight, 743 | scalar_t* grad_self_edge_weight, 744 | const unsigned int f_dim 745 | ){ 746 | unsigned int src_id = blockIdx.x; 747 | unsigned int tid = threadIdx.x; 748 | 749 | __shared__ scalar_t s_grad_edge_weight[blockSize]; 750 | s_grad_edge_weight[tid] = 0; 751 | scalar_t grad_out_buffer = 0; 752 | 753 | for (unsigned int f_idx = tid; f_idx < f_dim; f_idx += blockDim.x){ 754 | 
grad_out_buffer = grad_out[src_id][f_idx]; 755 | s_grad_edge_weight[tid] += grad_out_buffer * feature[src_id][f_idx]; 756 | grad_feature[src_id][f_idx] = grad_out_buffer * self_edge_weight[src_id]; 757 | } 758 | __syncthreads(); 759 | 760 | smem_reduce_v3(s_grad_edge_weight, tid); 761 | if (tid == 0){ 762 | grad_self_edge_weight[src_id] = s_grad_edge_weight[0]; 763 | } 764 | } 765 | 766 | 767 | std::vector fused_gas_b_cuda( 768 | torch::Tensor grad_out, 769 | torch::Tensor feature, 770 | torch::Tensor src_index, 771 | torch::Tensor tar_index, 772 | torch::Tensor edge_weight, 773 | torch::Tensor self_edge_weight, 774 | bool require_edge_weight 775 | ){ 776 | unsigned int f_dim = grad_out.size(1); 777 | auto grad_feature = torch::empty_like(grad_out); 778 | unsigned int num_edge = edge_weight.size(0); 779 | unsigned int num_node = feature.size(0); 780 | auto grad_edge_weight = torch::empty_like(edge_weight); 781 | auto grad_self_edge_weight = torch::empty_like(self_edge_weight); 782 | 783 | if (require_edge_weight){ 784 | if (f_dim <= GAS_THREADS){ 785 | // gradient from the self loop 786 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "scaled_clone", ([&]{ 787 | grad_self_loop<<>>( 788 | grad_out.packed_accessor(), 789 | feature.packed_accessor(), 790 | grad_feature.packed_accessor(), 791 | self_edge_weight.data(), grad_self_edge_weight.data(), 792 | f_dim 793 | ); 794 | })); 795 | }else{ 796 | // gradient from the self loop 797 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "scaled_clone", ([&]{ 798 | grad_self_loop_large<<>>( 799 | grad_out.packed_accessor(), 800 | feature.packed_accessor(), 801 | grad_feature.packed_accessor(), 802 | self_edge_weight.data(), grad_self_edge_weight.data(), 803 | f_dim 804 | ); 805 | })); 806 | } 807 | }else{ 808 | // gradient from the self loop 809 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "scaled_clone", ([&]{ 810 | scaled_clone<<>>( 811 | grad_out.data(), grad_feature.data(), self_edge_weight.data(), 812 | f_dim, grad_out.numel() 813 | ); 814 | })); 815 | } 816 | 817 | // gradient from the neighbors 818 | // gradient of edge weight 819 | if (f_dim <= GAS_THREADS){ 820 | unsigned int stride = GAS_THREADS / f_dim; 821 | if (require_edge_weight){ 822 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "fused_gas_b_kernelv2", ([&]{ 823 | fused_gas_b_kernelv2<<>>( 824 | grad_out.packed_accessor(), 825 | feature.packed_accessor(), 826 | grad_feature.packed_accessor(), 827 | src_index.data(), tar_index.data(), 828 | edge_weight.data(), grad_edge_weight.data(), stride, f_dim, num_edge 829 | ); 830 | })); 831 | }else{ 832 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "fused_gas_b_kernel", ([&]{ 833 | fused_gas_b_kernel<<>>( 834 | grad_out.packed_accessor(), 835 | grad_feature.packed_accessor(), 836 | src_index.data(), tar_index.data(), 837 | edge_weight.data(), stride, f_dim, num_edge 838 | ); 839 | })); 840 | } 841 | }else{ 842 | if (require_edge_weight){ 843 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "fused_gas_b_large_kernel", ([&]{ 844 | fused_gas_b_large_kernelv2<<>>( 845 | grad_out.packed_accessor(), 846 | feature.packed_accessor(), 847 | grad_feature.packed_accessor(), 848 | src_index.data(), tar_index.data(), 849 | edge_weight.data(), grad_edge_weight.data(), f_dim, num_edge 850 | ); 851 | })); 852 | }else{ 853 | AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "fused_gas_b_large_kernel", ([&]{ 854 | fused_gas_b_large_kernel<<>>( 855 | grad_out.packed_accessor(), 856 | grad_feature.packed_accessor(), 857 | src_index.data(), tar_index.data(), 858 | 
edge_weight.data(), f_dim, num_edge 859 | ); 860 | })); 861 | } 862 | } 863 | return {grad_feature, grad_edge_weight, grad_self_edge_weight}; 864 | } 865 | --------------------------------------------------------------------------------
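A note on the aggregation kernels above: the GAR and GAS paths compute the same weighted neighborhood aggregation and differ only in edge layout. GAR walks a CSR-style (`tar_ptr`, `src_index`) pair and reduces each target's incoming edges in shared memory, while GAS iterates COO edges and scatters partial sums with `atomicAdd` after `scaled_clone` has seeded the output with the self-loop term. The PyTorch sketch below is a minimal reference of that forward computation, useful for checking the fused kernels; the function names are illustrative and not part of the package.

```python
import torch


def gar_forward_ref(feature, src_index, tar_ptr, edge_weight, self_edge_weight):
    """CSR-style reference: tar_ptr[t]..tar_ptr[t+1] indexes the incoming edges of target t."""
    out = feature * self_edge_weight.unsqueeze(1)  # self-loop term
    for t in range(feature.size(0)):
        for e in range(int(tar_ptr[t]), int(tar_ptr[t + 1])):
            out[t] += feature[int(src_index[e])] * edge_weight[e]
    return out


def gas_forward_ref(feature, src_index, tar_index, edge_weight, self_edge_weight):
    """COO-style reference: one scatter-add per edge, mirroring the atomicAdd in fused_gas_f_kernel."""
    out = feature * self_edge_weight.unsqueeze(1)  # what scaled_clone seeds the output with
    src, tar = src_index.long(), tar_index.long()
    out.index_add_(0, tar, feature[src] * edge_weight.unsqueeze(1))
    return out
```

Up to floating-point summation order (the atomic path does not fix the order of additions), both references should match the output of `fused_gar_f_cuda` / `fused_gas_f_cuda` on the same inputs.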
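The backward kernels return the triple `{grad_feature, grad_edge_weight, grad_self_edge_weight}`. Writing the forward pass as out[t] = self_edge_weight[t] * feature[t] + (sum over edges e with tar(e) = t of edge_weight[e] * feature[src(e)]), the three gradients follow directly; the sketch below restates them in PyTorch as a reference against the CUDA output (again, the function name is illustrative).

```python
import torch


def gas_backward_ref(grad_out, feature, src_index, tar_index, edge_weight, self_edge_weight):
    """Reference gradients for out[t] = self_edge_weight[t]*feature[t] + sum_e edge_weight[e]*feature[src(e)]."""
    src, tar = src_index.long(), tar_index.long()

    # dL/dfeature: self-loop term, plus each edge scattering grad_out[tar(e)] * edge_weight[e] back to its source
    grad_feature = grad_out * self_edge_weight.unsqueeze(1)
    grad_feature.index_add_(0, src, grad_out[tar] * edge_weight.unsqueeze(1))

    # dL/dedge_weight[e] = <grad_out[tar(e)], feature[src(e)]>
    grad_edge_weight = (grad_out[tar] * feature[src]).sum(dim=1)

    # dL/dself_edge_weight[n] = <grad_out[n], feature[n]>
    grad_self_edge_weight = (grad_out * feature).sum(dim=1)

    return grad_feature, grad_edge_weight, grad_self_edge_weight
```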
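At the Python level these entry points back the `fused_gas_agg` call that the GAT layer above issues during `propagate` (its wrapper lives in `fuseGNN/functional/aggregate.py`, not reproduced in this excerpt). The sketch below only illustrates how such a forward/backward pair typically sits inside a `torch.autograd.Function`; the binding names `gcnlib_cuda.fused_gas_f` and `gcnlib_cuda.fused_gas_b` are assumptions for illustration, not the extension's documented API.

```python
import torch
import gcnlib_cuda  # the project's CUDA extension; the binding names used below are assumed


class FusedGASAgg(torch.autograd.Function):
    """Illustrative pairing of the GAS forward and backward entry points (not the package's actual wrapper)."""

    @staticmethod
    def forward(ctx, feature, src_index, tar_index, edge_weight, self_edge_weight):
        ctx.save_for_backward(feature, src_index, tar_index, edge_weight, self_edge_weight)
        # assumed to dispatch to fused_gas_f_cuda
        return gcnlib_cuda.fused_gas_f(feature, src_index, tar_index, edge_weight, self_edge_weight)

    @staticmethod
    def backward(ctx, grad_out):
        feature, src_index, tar_index, edge_weight, self_edge_weight = ctx.saved_tensors
        # assumed to dispatch to fused_gas_b_cuda, which returns
        # {grad_feature, grad_edge_weight, grad_self_edge_weight}
        grad_feature, grad_edge_weight, grad_self_edge_weight = gcnlib_cuda.fused_gas_b(
            grad_out.contiguous(), feature, src_index, tar_index,
            edge_weight, self_edge_weight, True)
        return grad_feature, None, None, grad_edge_weight, grad_self_edge_weight
```

Gradient-checking such a wrapper against `gas_backward_ref` above (with double-precision inputs) is a straightforward way to validate the kernels.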